mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[feat] support fa3 backend for pd disaggregated (#2695)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* support fa3 backend run in pd disaggregated
* delete use_fast_ffn
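This commit wires a FlashAttention v3 (FA3) varlen backend into FastDeploy's attention backend registry and lets it run under prefill/decode (PD) disaggregation. A minimal sketch of how it would be switched on is shown below; it assumes selection goes through the FD_ATTENTION_BACKEND environment variable referenced by the error messages in this diff and PD mode through FLAGS_use_pd_disaggregation read by the new backend. The launch flow itself is not part of this commit.

# Hedged sketch (not from this commit): environment knobs that the diff below reads.
import os

os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"   # value format assumed; intended to map to _Backend.FLASH_ATTN
os.environ["FLAGS_use_pd_disaggregation"] = "1"     # FlashAttentionBackend reads this via os.getenv

from fastdeploy.model_executor.layers.attention import FlashAttentionBackend  # exported by this commit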
@@ -5,12 +5,6 @@ default_stages:
  - pre-commit # Run locally
  # - manual # Run in CI
repos:
  # Formatting
  - repo: https://github.com/google/yapf
    rev: v0.43.0
    hooks:
      - id: yapf
        args: [--in-place, --verbose]
  # Linting
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.7
@@ -29,15 +23,6 @@ repos:
    rev: 6.0.1
    hooks:
      - id: isort
  # # Formatting
  # - repo: https://github.com/pre-commit/mirrors-clang-format
  #   rev: v20.1.3
  #   hooks:
  #     - id: clang-format
  #       # exclude: '.*'
  #       types_or: [c++, cuda]
  #       args: [--style=file, --verbose]

  # markdown
  - repo: https://github.com/jackdewinter/pymarkdown
    rev: v0.9.29
@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle

from . import jit
from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
from typing import Tuple

from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME

def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
-import torch
+import paddle
from typing import Optional

from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
-import torch
+import paddle
from typing import Any, Dict, Iterable, Tuple

# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}

# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
}

@@ -27,25 +27,25 @@ genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
}

def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
- if value.dtype == torch.int:
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
+import paddle
from functools import lru_cache
from typing import Tuple

@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config

-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor) -> None:
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow paddle operations.

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape

- assert n % 64 == 0 and k % 128 == 0
+ # assert n % 64 == 0 and k % 128 == 0

# Type and shape checks
- assert m == m_ and n == n_ and k == k_
- assert n > 0 and k > 0
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Do nothing if `m` is zero
if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
-import torch
+import paddle
from typing import Tuple

from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
"""

-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow Pypaddle operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
m__ = m_indices.numel()

# Type and shape checks
- assert m == m_ == m__ and k == k_ and n == n_
- assert lhs_scales.shape == (m, (k + 127) // 128)
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
+ # assert m_indices.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and m_indices.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Do nothing if `m` is zero
if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
)
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
runtime(*args)

-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow paddle operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor]
num_groups___ = masked_m.numel()

# Type and shape checks
- assert num_groups == num_groups_ == num_groups__ == num_groups___
- assert m == m_ and n == n_ and k == k_
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
+ # assert masked_m.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and masked_m.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Auto-tuning with compilation
global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor]

args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
- torch.cuda.current_stream(), num_sms, smem_config[0])
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
-import torch
+import paddle
from typing import Any, Dict

from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
continue

# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle

_num_sms = None

@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
_num_sms = num_sms

@@ -25,7 +25,7 @@ def get_num_sms() -> int:
"""
global _num_sms
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms

@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
return ceil_div(x, alignment) * alignment

-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
"""
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.

@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True

b = x.shape[0]

# The last kernel gives a column-major TMA aligned layout
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
return x.squeeze(0) if remove_dim else x

# Normal layout requires transposing
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+ aligned_x = paddle.transpose(
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist

def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
# Flush L2 cache with 256 MB data
- torch.cuda.synchronize()
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+ paddle.device.cuda.synchronize()
+ paddle.device.synchronize()
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
cache.zero_()

# Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,

# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
x @ y

# Testing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
end_event.record()
- torch.cuda.synchronize()
+ paddle.device.synchronize()

return start_event.elapsed_time(end_event) / num_tests

@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
fn()

if not using_nsys:
--
2.43.0
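Before the per-file hunks for FastDeploy itself, here is a compact reference of the PyTorch-to-Paddle substitutions made in the DeepGEMM patch above. Every pair is taken directly from the hunks; no APIs beyond those shown are implied.

# Summary of the torch -> paddle replacements visible in the patch hunks above.
TORCH_TO_PADDLE = {
    "torch.cuda.synchronize()": "paddle.device.synchronize()",
    "torch.cuda.get_device_properties(device='cuda')": "paddle.device.cuda.get_device_properties()",
    "torch.empty(n, dtype=torch.int, device='cuda')": "paddle.empty(n, dtype=paddle.int32)",
    "torch.randn(shape, dtype=torch.float, device='cuda')": "paddle.randn(shape, dtype=paddle.float32)",
    "torch.float8_e4m3fn": "paddle.float8_e4m3fn",
    "torch.cuda.Stream": "paddle.device.cuda.Stream",
    "x.stride(i)": "x.strides[i]",
}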
@@ -18,7 +18,7 @@ from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Literal
from typing import Literal, Optional

from paddleformers.transformers.configuration_utils import PretrainedConfig

@@ -69,7 +69,6 @@ class ModelConfig(PretrainedConfig):
max_seq_len: int = 512,
initializer_range: float = 0.02,
use_rope=True,
use_fast_ffn: bool = False,
rope_theta: int = 10000,
rope_3d: bool = False,
ori_vocab_size: int | None = None,
@@ -104,7 +103,6 @@ class ModelConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.use_rope = use_rope
self.use_fast_ffn = use_fast_ffn
self.rope_theta = rope_theta
self.ori_vocab_size = ori_vocab_size or vocab_size
self.max_seq_len = max_seq_len
@@ -199,7 +197,7 @@ class ParallelConfig:
eos_tokens_lens: int = 2
# Enable chunked prefill
enable_chunked_prefill: str = "store_true"
#

max_num_batched_tokens: int = 2048
# enable prefix cache
enable_prefix_caching = None
@@ -349,7 +347,7 @@ class GraphOptimizationConfig:
class LoadConfig:
"""
Configuration for dynamic weight loading strategies

Attributes:
dynamic_load_weight: Whether to enable dynamic weight loading
load_strategy: Specifies the weight loading method when enabled:
@@ -366,7 +364,7 @@ class LoadConfig:
def __post_init__(self):
if self.load_strategy is not None and not self.dynamic_load_weight:
    raise ValueError("Load strategy requires dynamic_load_weight=True")

if self.dynamic_load_weight and self.load_strategy is None:
    raise ValueError("Must specify load_strategy when dynamic_load_weight is True")

@@ -728,7 +728,7 @@ class Config:
), "XPU currently do not support guided_decoding"

try:
import xgrammar
import xgrammar  # noqa
except Exception as e:
raise Exception(
    f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
@@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .attention import Attention
from .append_attn_backend import AppendAttentionBackend
from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .mla_attention_backend import MLAAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend

__all__ = [
    "Attention", "AttentionBackend", "PaddleNativeAttnBackend",
    "AttentionBackend", "PaddleNativeAttnBackend",
    "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
    "MLAAttentionBackend"
    "MLAAttentionBackend", "FlashAttentionBackend"
]
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
fastdeploy/model_executor/layers/attention/flash_attn_backend.py (new file, 243 lines)
@@ -0,0 +1,243 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
import paddle
|
||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
FlashAttentionMetadata
|
||||
"""
|
||||
max_len_kv: paddle.Tensor = None
|
||||
set_max_lengths: int = -1
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
decoder_batch_ids: paddle.Tensor = None
|
||||
decoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
decoder_num_blocks: paddle.Tensor = None
|
||||
|
||||
encoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
decoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
|
||||
cu_seqlens_q: paddle.Tensor = None
|
||||
cu_seqlens_k: paddle.Tensor = None
|
||||
max_seqlen_q: int = 0
|
||||
max_seqlen_k: int = 0
|
||||
|
||||
pre_cache_batch_ids = None
|
||||
pre_cache_tile_ids_per_batch = None
|
||||
pre_cache_num_blocks_cpu = None
|
||||
kv_token_num_cpu = None
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
FlashAttentionBackend backend implementation
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||
head_dim: int):
|
||||
"""
|
||||
FlashAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: FlashAttentionMetadata = None
|
||||
self.max_seq_len = fd_config.parallel_config.max_model_len
|
||||
self.causal = getattr(fd_config.model_config, "causal", True)
|
||||
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.block_size = fd_config.parallel_config.block_size
|
||||
self.num_layers: int = fd_config.model_config.num_layers
|
||||
|
||||
self.speculative_method = fd_config.speculative_config.method
|
||||
self.use_speculate = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
fd_config.parallel_config.expert_parallel_rank = 0
|
||||
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
|
||||
fd_config.parallel_config.expert_parallel_rank
|
||||
if self.device_id is None:
|
||||
self.device_id = device_id
|
||||
else:
|
||||
self.device_id = self.device_id.split(",")[device_id]
|
||||
|
||||
def get_attntion_meta(self):
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
def get_kv_cache_shape(
|
||||
self,
|
||||
max_num_blocks: int,
|
||||
):
|
||||
"""
|
||||
Caculate kv cache shape
|
||||
"""
|
||||
return (max_num_blocks, self.kv_num_heads, self.block_size,
|
||||
self.head_dim)
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
metadata = FlashAttentionMetadata()
|
||||
metadata.encoder_block_shape_q = 64
|
||||
metadata.decoder_block_shape_q = 16
|
||||
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
metadata.set_max_lengths,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.encoder_block_shape_q,
|
||||
metadata.decoder_block_shape_q,
|
||||
self.num_heads // self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
)
|
||||
|
||||
(
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
metadata.kv_token_num_cpu,
|
||||
) = pre_cache_len_concat(
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
metadata.set_max_lengths[2],
|
||||
self.block_size,
|
||||
)
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
self.attention_metadata = metadata
|
||||
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
||||
forward_meta.decoder_tile_ids_per_batch.copy_(
|
||||
metadata.decoder_tile_ids_per_batch, False)
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
q, k, v, _ = gqa_rope_write_cache(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.rotary_embs,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
metadata.kv_token_num_cpu[0],
|
||||
self.max_seq_len,
|
||||
getattr(layer, "cache_quant_type_str", "none"),
|
||||
)
|
||||
res = flash_attention_v3_varlen(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
max_seqlen_q=metadata.set_max_lengths[0],
|
||||
max_seqlen_k=metadata.set_max_lengths[3],
|
||||
causal=self.causal,
|
||||
)[0].reshape([-1, self.hidden_size])
|
||||
return res
|
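For readers unfamiliar with the varlen layout consumed by flash_attention_v3_varlen above: queries and keys of all requests are packed along one axis and described by cumulative-length vectors, so sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1]. A small illustrative snippet follows; it is not part of the commit, and in the real backend cu_seqlens_q and cu_seqlens_k come from forward_meta and pre_cache_len_concat.

# Illustration only: building a cumulative-length vector from per-request lengths.
import paddle

seq_lens_this_time = paddle.to_tensor([5, 1, 3], dtype="int32")   # example lengths
cu_seqlens_q = paddle.concat(
    [paddle.zeros([1], dtype="int32"), paddle.cumsum(seq_lens_this_time)])
print(cu_seqlens_q.numpy())   # [0 5 6 9]; 9 packed query rows in total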
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
@@ -17,10 +17,16 @@
from .append_attention import append_attention
from .get_block_shape_and_split_kv_block import \
    get_block_shape_and_split_kv_block
from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_signal_layerwise import init_signal_layerwise
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
from .pre_cache_len_concat import pre_cache_len_concat

__all__ = [
    "get_block_shape_and_split_kv_block", "append_attention",
    "open_shm_and_get_meta_signal", "init_signal_layerwise"
    "get_block_shape_and_split_kv_block",
    "append_attention",
    "open_shm_and_get_meta_signal",
    "init_signal_layerwise",
    "gqa_rope_write_cache",
    "pre_cache_len_concat",
]
@@ -0,0 +1,66 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Optional

import paddle

from fastdeploy.platforms import current_platform


def gqa_rope_write_cache(
        qkv: paddle.Tensor,
        key_cache: paddle.Tensor,
        value_cache: paddle.Tensor,
        cu_seqlens_q: paddle.Tensor,
        cu_seqlens_k: paddle.Tensor,
        rotary_embs: paddle.Tensor,
        seq_lens_this_time: paddle.Tensor,
        seq_lens_encoder: paddle.Tensor,
        seq_lens_decoder: paddle.Tensor,
        padding_offsets: paddle.Tensor,
        cum_offsets: paddle.Tensor,
        block_tables: paddle.Tensor,
        kv_batch_ids: paddle.Tensor,
        kv_tile_ids_per_batch: paddle.Tensor,
        kv_num_blocks: paddle.Tensor,
        cache_batch_ids: paddle.Tensor,
        cache_tile_ids_per_batch: paddle.Tensor,
        cache_num_blocks: paddle.Tensor,
        cache_k_quant_scales: Optional[paddle.Tensor] = None,
        cache_v_quant_scales: Optional[paddle.Tensor] = None,
        cache_k_dequant_scales: Optional[paddle.Tensor] = None,
        cache_v_dequant_scales: Optional[paddle.Tensor] = None,
        cache_k_zp: Optional[paddle.Tensor] = None,
        cache_v_zp: Optional[paddle.Tensor] = None,
        kv_signal_data: Optional[paddle.Tensor] = None,
        kv_token_num: int = 1,
        max_seq_len: int = 0,
        cache_quant_type: str = "none"):
    if current_platform.is_cuda():
        from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
        q, k, v, qkv_ = gqa_rope_write_cache(
            qkv, key_cache, value_cache, cu_seqlens_q, cu_seqlens_k,
            rotary_embs, seq_lens_this_time, seq_lens_encoder,
            seq_lens_decoder, padding_offsets, cum_offsets, block_tables,
            kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
            cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks,
            cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales,
            cache_v_dequant_scales, cache_k_zp, cache_v_zp, kv_signal_data,
            kv_token_num, max_seq_len, cache_quant_type)
        return q, k, v, qkv_
    else:
        raise NotImplementedError()
@@ -0,0 +1,36 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle

from fastdeploy.platforms import current_platform


def pre_cache_len_concat(seq_lens_decoder: paddle.Tensor,
                         seq_lens_this_time: paddle.Tensor,
                         max_dec_len: int = 0,
                         block_size: int = 64):
    if current_platform.is_cuda():
        from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
        out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time,
                                   max_dec_len, block_size)
        return out
    else:
        raise NotImplementedError()
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
@@ -329,7 +329,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
with_bias: bool = False,
add_bias: bool = False,
activation: str = "gelu",
use_fast_ffn: bool = False,
skip_quant: bool = False,
):
"""
@@ -344,11 +343,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
activation (str): Activation function to use. Defaults to "gelu".
use_fast_ffn (bool): Whether to use a faster FFN implementation.
    Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.use_fast_ffn = use_fast_ffn
self.activation = activation
self.hidden_size = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_degree
@@ -385,23 +381,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
    "gate_proj")
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
    paddle.get_default_dtype())
converted_bias_tensor = paddle.zeros(shape=list(
    bias_tensor.shape),
    dtype=bias_tensor.dtype)
if not self.use_fast_ffn:
    converted_bias_tensor = paddle.concat(
        [bias_tensor[::2], bias_tensor[1::2]], axis=0)
else:
    converted_bias_tensor = bias_tensor
state_dict[self.bias_key] = converted_bias_tensor

if not self.use_fast_ffn:
    converted_weight_tensor = paddle.concat(
        [weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1)
else:
    converted_weight_tensor = weight_tensor
state_dict[self.bias_key] = bias_tensor

state_dict[self.weight_key] = converted_weight_tensor
state_dict[self.weight_key] = weight_tensor

super().load_state_dict(state_dict)
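The use_fast_ffn switch removed above governed whether fused gate/up projection weights had to be de-interleaved at load time. A small self-contained illustration of that de-interleave follows (shapes invented for the demo), to make clear what the deleted branch used to do.

# Illustration of the column de-interleave performed by the removed
# use_fast_ffn=False path: even-indexed columns and odd-indexed columns are
# regrouped into two contiguous halves.
import paddle

weight = paddle.arange(12, dtype="float32").reshape([2, 6])
converted = paddle.concat([weight[:, ::2], weight[:, 1::2]], axis=1)
print(converted.numpy())
# row 0: [0. 2. 4. 1. 3. 5.]  (even columns first, then odd columns)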
@@ -15,8 +15,9 @@
"""
from typing import Optional

from ..attention import Attention
from ..moe import FusedMoE
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.moe.moe import FusedMoE

from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -214,7 +214,7 @@ def load_tp_checkpoint_v1(
need_tp = True if tensor_parallel_filtered_map else False
state_dict = {}
for key, weight in weights_iterator:
    paddle.device.cuda.synchronize()
    paddle.device.synchronize()
    if need_tp and key in tensor_parallel_filtered_map:
        action = tensor_parallel_filtered_map.pop(key)
        tensor = action(weight).clone()
@@ -86,7 +86,7 @@ class DefaultModelLoader(BaseModelLoader):
if isinstance(v, paddle.Tensor):
    v.value().get_tensor()._clear()
paddle.device.cuda.empty_cache()
paddle.device.cuda.synchronize()
paddle.device.synchronize()

def load_model(self, fd_config: FDConfig) -> nn.Layer:
    context = paddle.LazyGuard()
@@ -28,7 +28,7 @@ from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
    tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
    ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear,
@@ -68,7 +68,6 @@ class DeepSeekV3MLP(nn.Layer):
output_size=intermediate_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)

self.down_proj = RowParallelLinear(
@@ -29,7 +29,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
    support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -62,7 +62,6 @@ class Ernie4_5_MLP(nn.Layer):
output_size=intermediate_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)

self.down_proj = RowParallelLinear(
@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
    support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -55,7 +55,6 @@ class Qwen2MLP(nn.Layer):
output_size=fd_config.model_config.ffn_hidden_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)

self.down_proj = RowParallelLinear(
@@ -26,7 +26,7 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
    support_graph_optimization
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (QKVParallelLinear,
                                                     RowParallelLinear)
@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
    support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -57,7 +57,6 @@ class Qwen3MLP(nn.Layer):
output_size=fd_config.model_config.ffn_hidden_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)

self.down_proj = RowParallelLinear(
@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
NATIVE_ATTN = enum.auto()
APPEND_ATTN = enum.auto()
MLA_ATTN = enum.auto()
FLASH_ATTN = enum.auto()


class Platform:
@@ -13,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
cuda platform file
"""

import paddle

@@ -65,6 +62,11 @@ class CUDAPlatform(Platform):
return (
    "fastdeploy.model_executor.layers.attention.MLAAttentionBackend"
)
elif selected_backend == _Backend.FLASH_ATTN:
    logger.info("Using FLASH ATTN backend.")
    return (
        "fastdeploy.model_executor.layers.attention.FlashAttentionBackend"
    )
else:
    raise ValueError(
        "Invalid attention backend you specified.\n"
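The platform hook above returns a dotted class path rather than a class object. A hedged sketch of how such a path might be resolved is shown below; the real resolution lives in attention_selecter.get_attention_backend, which is not part of this diff, so treat this as illustrative only.

# Hypothetical resolution of the dotted backend path returned by CUDAPlatform.
import importlib

def resolve_backend(dotted_path: str):
    # Split "pkg.module.Class" into module and attribute, then import it.
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

backend_cls = resolve_backend(
    "fastdeploy.model_executor.layers.attention.FlashAttentionBackend")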
@@ -90,7 +90,8 @@ class MTPProposer(Proposer):

self.model = get_model_from_loader(self.cfg)

def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                         expected_decode_len: int):
    """Set dummy prefill inputs to model_inputs"""
    max_dec_len = expected_decode_len + 1
    self.num_gpu_blocks = self.parallel_config.max_block_num
@@ -130,10 +131,10 @@ class MTPProposer(Proposer):
self.cache_kvs = {}

cache_type = self.parallel_config.dtype

if (self.quant_config and
        hasattr(self.quant_config, "kv_cache_quant_type") and
        self.quant_config.kv_cache_quant_type is not None):

if (self.quant_config
        and hasattr(self.quant_config, "kv_cache_quant_type")
        and self.quant_config.kv_cache_quant_type is not None):
    cache_type = 'uint8'

# Get kv cache shape
@@ -190,8 +191,7 @@ class MTPProposer(Proposer):
head_dim = self.model_config.head_dim

# Get the attention backend
attn_cls = get_attention_backend(
    self.parallel_config.attention_backend)
attn_cls = get_attention_backend()
attn_backend = attn_cls(
    self.cfg,
    kv_num_heads=self.model_config.kv_num_heads,
@@ -200,8 +200,8 @@ class MTPProposer(Proposer):
)
if attn_backend is None:
    raise NotImplementedError(
        f"{ self.parallel_config.attention_backend} attention backend"
        " is not support by GPUModelRunner")
        "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
    )
self.attn_backends.append(attn_backend)

def clear_dummy_input(self):
@@ -593,7 +593,8 @@ class GPUModelRunner(ModelRunnerBase):
self.model = get_model_from_loader(fd_config=self.fd_config)
# 1.1 Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
    from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
    from fastdeploy.rl.dynamic_weight_manager import \
        DynamicWeightManager
    self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)

# 2. Load lora model
@@ -622,7 +623,7 @@ class GPUModelRunner(ModelRunnerBase):
# Initialzie attention meta data
for attn_backend in self.attn_backends:
    attn_backend.init_attention_metadata(self.forward_meta)

def clear_cache(self):
    """Clear cached data from shared inputs and forward metadata."""
    self.share_inputs.pop("caches", None)
@@ -719,7 +720,7 @@ class GPUModelRunner(ModelRunnerBase):
    head_dim=head_dim)
if attn_backend is None:
    raise NotImplementedError(
        "Attention backend which you chose is not support by GPUModelRunner"
        "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
    )
self.attn_backends.append(attn_backend)

@@ -1150,7 +1151,6 @@ class GPUModelRunner(ModelRunnerBase):

if self.speculative_method in ["mtp"]:
    self.proposer.clear_dummy_input()
# paddle.device.cuda.synchronize()

def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
    """
@@ -590,7 +590,7 @@ class XPUModelRunner(ModelRunnerBase):
    head_dim=head_dim)
if attn_backend is None:
    raise NotImplementedError(
        "Attention backend which you chose is not support by GPUModelRunner"
        "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
    )
self.attn_backends.append(attn_backend)