Compare commits

...

31 Commits

Author SHA1 Message Date
Jiang-Jia-Jun
3ec126dc02 Update setup.py 2025-07-15 14:57:40 +08:00
gaoziyuan
337d76f094 [sync fix] (#2759)
* add rl qwen model support

* fix

* fix

* add_commit_config

* fix
2025-07-08 19:29:23 +08:00
gaoziyuan
ae2f78184d 【Sync develop】 add commit info (#2755)
* add rl qwen model support

* fix

* fix

* add_commit_config
2025-07-08 17:02:50 +08:00
gaoziyuan
6851489425 【Sync】Release/2.0.1 (#2745)
* add rl qwen model support

* fix

* fix
2025-07-08 14:38:18 +08:00
Jiang-Jia-Jun
ea787d8f62 fix bug. (#2718) (#2720)
Co-authored-by: Ting <wtmlon@foxmail.com>
2025-07-05 09:00:01 +08:00
Ting
90ef28d982 spec token map lazy. (#2715)
2025-07-05 00:14:54 +08:00
YuBaoku
b37585e693 [BugFix] fix paddle_git_commit_id error (#2714)
* set git identity to avoid merge failure in CI

* add ci cases

* [CI] Add validation for MTP and CUDAGraph

* [BugFix] fix paddle_git_commit_id error
2025-07-04 22:16:37 +08:00
lizexu123
9cb08e71e8 add support QWQ enable_thinking (#2706)
* add support QWQ enable_thinking

* add stream=True

* fix stream=true

* fix qwen

---------

Co-authored-by: lizexu <lizexu@baidu.com>
2025-07-04 20:55:23 +08:00
YuBaoku
dacc46f04c [CI] Add validation for MTP and CUDAGraph (#2710)
* set git identity to avoid merge failure in CI

* add ci cases

* [CI] Add validation for MTP and CUDAGraph
2025-07-04 18:13:54 +08:00
Jiang-Jia-Jun
09ded7715f Update mkdocs.yml 2025-07-04 17:55:52 +08:00
LQX
11cfdf5d89 Add XPU CI, test=model (#2701)
* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model
2025-07-04 16:16:06 +08:00
GoldPancake
e7fa57ebae Extract eh_proj Layer from ParallelLMHead for MTP to Avoid Weight Transposition Issue (#2707)
* fix mtp eh_proj layer

* fix mtp update_cfg function

* fix stringdoc

* simplify class name
2025-07-04 14:15:04 +08:00
gaoziyuan
a5ae88ded9 [feature]add fd whl version info (#2698) 2025-07-04 14:12:42 +08:00
ltd0924
87e638498c [RL] update reschedule finish reason (#2709) 2025-07-04 13:47:36 +08:00
freeliuzc
667547be59 support chunk_prefill in MTP (#2705) 2025-07-04 11:55:48 +08:00
LiqinruiG
b38823bc66 modify reasoning_output docs (#2696) 2025-07-04 11:30:02 +08:00
Divano
050d9658a5 Update requirements.txt 2025-07-04 09:53:03 +08:00
Divano
be5cabaf80 add quick benchmark (#2703)
Test scripts do not need to go through CI
2025-07-04 09:32:36 +08:00
Yuanle Liu
240bdac2a4 [feat] support fa3 backend for pd disaggregated (#2695)
* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* delete use_fast_ffn
2025-07-03 22:33:27 +08:00
ltd0924
00863c43fd [Bug] fix logger format (#2689)
2025-07-03 19:58:03 +08:00
kevin
3d3bccdf79 [doc] update docs (#2690) 2025-07-03 19:33:19 +08:00
Jiang-Jia-Jun
9fd74f75bd Update dynamic_weight_manager.py 2025-07-03 15:55:22 +08:00
Jiang-Jia-Jun
05c670e593 [Sync] Update to latest code (#2679)
* [Sync] Update to latest code

* Add new code files

* Add new code files

* update code

* Try to fix build.sh

* Try to fix build.sh

* Update code

* Update requirements.txt

* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
2025-07-03 15:43:53 +08:00
Jiang-Jia-Jun
d222248d00 Update README.md 2025-07-03 15:28:28 +08:00
Jiang-Jia-Jun
e5b94d4117 Update README.md 2025-07-03 15:28:05 +08:00
Jiang-Jia-Jun
87e2e58a22 Update gh-pages.yml 2025-07-03 15:26:21 +08:00
Jiang-Jia-Jun
de20e5a992 Update Dockerfile.xpu
2025-07-03 10:14:50 +08:00
Jiang-Jia-Jun
2f9c0618f0 Update Dockerfile.gpu 2025-07-03 10:14:39 +08:00
Yuanle Liu
9a14ab6572 add --force-reinstall --no-cache-dir when pip install fastdeploy*.whl (#2682)
2025-07-02 05:32:20 -07:00
Divano
d1cb3ed571 Update gh-pages.yml (#2680) 2025-07-02 17:36:18 +08:00
handiz
b8a8a19689 add wint2 performance (#2673) 2025-07-02 17:10:01 +08:00
134 changed files with 12860 additions and 1874 deletions

.github/workflows/ci_xpu.yml (new file, 83 lines)
View File

@@ -0,0 +1,83 @@
name: CI_XPU
on:
pull_request:
branches: [ develop ]
workflow_dispatch:
concurrency:
group: ${{ github.event.pull_request.number }}-xpu-ci
cancel-in-progress: true
jobs:
build:
runs-on: [self-hosted, XPU-P800-8Card]
steps:
- name: Print current runner name
run: |
echo "Current runner name: ${{ runner.name }}"
# Because the system version is lower than 2.23, the checkout cannot be used.
# - name: Checkout code
# uses: actions/checkout@v4
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}
fi
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git clone ${REPO} ${REPO_NAME}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
git merge pr/${{ github.event.pull_request.number }}
git log -n 3 --oneline
else
git checkout ${{ github.sha }}
git log -n 3 --oneline
fi
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
if [[ "$last_char" =~ [0-3] ]]; then
gpu_id="$last_char"
else
gpu_id="0"
fi
FD_API_PORT=$((9180 + gpu_id * 100))
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
FD_METRICS_PORT=$((9170 + gpu_id * 100))
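# Illustrative example (not from the source): a runner name ending in "2" gives
# gpu_id=2, so FD_API_PORT=9380, FD_ENGINE_QUEUE_PORT=9350, FD_METRICS_PORT=9370.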
PARENT_DIR=$(dirname "$WORKSPACE")
echo "PARENT_DIR:$PARENT_DIR"
docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \
-v $(pwd):/workspace -w /workspace \
-v "/ssd3:/ssd3" \
-e "MODEL_PATH=/ssd3/model" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
${docker_image} /bin/bash -c "
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci_xpu.sh
"

View File

@@ -3,8 +3,6 @@ name: Deploy GitHub Pages
on:
push:
branches: [ develop ]
pull_request:
branches: [ develop ]
permissions:
contents: write
@@ -21,4 +19,6 @@ jobs:
- name: Deploy to GitHub Pages
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: mkdocs gh-deploy --force --remote-name origin
run: |
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git
mkdocs gh-deploy --force --remote-name origin

View File

@@ -5,12 +5,6 @@ default_stages:
- pre-commit # Run locally
# - manual # Run in CI
repos:
# Formatting
- repo: https://github.com/google/yapf
rev: v0.43.0
hooks:
- id: yapf
args: [--in-place, --verbose]
# Code linting
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
@@ -29,15 +23,6 @@ repos:
rev: 6.0.1
hooks:
- id: isort
# # Formatting
# - repo: https://github.com/pre-commit/mirrors-clang-format
# rev: v20.1.3
# hooks:
# - id: clang-format
# # exclude: '.*'
# types_or: [c++, cuda]
# args: [--style=file, --verbose]
# markdown
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.29

File diff suppressed because it is too large.

View File

@@ -3,3 +3,4 @@ tqdm
numpy
Pillow
pyyaml
requests

View File

@@ -0,0 +1,3 @@
metadata:
min_tokens: 32
max_tokens: 33

View File

@@ -166,7 +166,7 @@ function build_and_install() {
echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
cd $DIST_DIR
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir
if [ $? -ne 0 ]; then
cd ..
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
@@ -176,6 +176,21 @@ function build_and_install() {
cd ..
}
function version_info() {
output_file="fastdeploy/version.txt"
fastdeploy_git_commit_id=$(git rev-parse HEAD)
paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
echo "Paddle version: $paddle_version" >> $output_file
echo "Paddle GIT COMMIT ID: $paddle_git_commit_id" >> $output_file
echo "CUDA version: $cuda_version" >> $output_file
echo "CXX compiler version: $cxx_version" >> $output_file
}
function cleanup() {
rm -rf $BUILD_DIR $EGG_DIR
if [ `${python} -m pip list | grep fastdeploy | wc -l` -gt 0 ]; then
@@ -207,6 +222,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
set -e
init
version_info
build_and_install_ops
build_and_install
cleanup
@@ -237,6 +253,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
else
init
build_and_install_ops
version_info
rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
fi
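
For reference, a sketch of the fastdeploy/version.txt that the version_info function above would write; the values below are illustrative placeholders, not taken from the source:

fastdeploy GIT COMMIT ID: <fastdeploy commit sha>
Paddle version: <paddle.__version__, e.g. 3.0.0>
Paddle GIT COMMIT ID: <paddle.__git_commit__>
CUDA version: <from nvcc -V, e.g. 12.3>
CXX compiler version: <from g++ --version, e.g. 12.2.0>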

View File

@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle
from . import jit
from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
-import torch
+import paddle
from typing import Optional
from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
-import torch
+import paddle
from typing import Any, Dict, Iterable, Tuple
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
}
@@ -27,25 +27,25 @@ genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
}
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
- if value.dtype == torch.int:
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
+import paddle
from functools import lru_cache
from typing import Tuple
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor) -> None:
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow paddle operations.
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape
- assert n % 64 == 0 and k % 128 == 0
+ # assert n % 64 == 0 and k % 128 == 0
# Type and shape checks
- assert m == m_ and n == n_ and k == k_
- assert n > 0 and k > 0
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
-import torch
+import paddle
from typing import Tuple
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
"""
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow paddle operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
m__ = m_indices.numel()
# Type and shape checks
- assert m == m_ == m__ and k == k_ and n == n_
- assert lhs_scales.shape == (m, (k + 127) // 128)
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
+ # assert m_indices.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and m_indices.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
)
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
runtime(*args)
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow paddle operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups___ = masked_m.numel()
# Type and shape checks
- assert num_groups == num_groups_ == num_groups__ == num_groups___
- assert m == m_ and n == n_ and k == k_
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
+ # assert masked_m.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and masked_m.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
- torch.cuda.current_stream(), num_sms, smem_config[0])
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
-import torch
+import paddle
from typing import Any, Dict
from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle
_num_sms = None
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
_num_sms = num_sms
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
"""
global _num_sms
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
return ceil_div(x, alignment) * alignment
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
"""
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+ aligned_x = paddle.transpose(
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
high_precision: bool = False):
# Flush L2 cache with 256 MB data
- torch.cuda.synchronize()
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+ paddle.device.cuda.synchronize()
+ paddle.device.synchronize()
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
cache.zero_()
# Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
x @ y
# Testing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
end_event.record()
- torch.cuda.synchronize()
+ paddle.device.synchronize()
return start_event.elapsed_time(end_event) / num_tests
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
fn()
if not using_nsys:
--
2.43.0

View File

@@ -0,0 +1,236 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "multi_head_latent_attention_kernel.h"
template <size_t vec_size, typename T>
struct softmax_state_t {
AlignedVector<T, vec_size> o;
T m;
T d;
__device__ __forceinline__ void init() {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = __float2half(-5e4f);
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = __float2bfloat16(-3.38953e38f);
}
}
__device__ __forceinline__ softmax_state_t() {
init();
}
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
T other_m,
T other_d) {
// using kType = typename cascade_attn_nv_type2_traits<T>::type;
T m_prev = m, d_prev = d;
m = m_prev > other_m ? m_prev : other_m;
T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);
d = d_prev * scale1 + other_d * scale2;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * scale1 + other_o[i] * scale2;
}
}
__device__ __forceinline__ void normalize() {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d;
}
}
};
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
struct softmax_state_ts {
uint32_t num_tiles_ = num_tiles;
AlignedVector<T, vec_size> o[num_tiles];
float m;
float d;
__device__ __forceinline__ void init() {
#pragma unroll
for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
}
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = -3.38953e38f;
}
}
__device__ __forceinline__ softmax_state_ts() {
init();
}
__device__ __forceinline__ void normalize(const uint32_t tile_id) {
#pragma unroll
for (size_t i = 0; i < vec_size; i++) {
o[tile_id][i] /= d;
}
}
};
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
__device__ __forceinline__ void produce_kv(CacheT *smem,
CacheT *kv_base_gptr,
const int * block_table_smem,
const uint32_t seq_offset_gmem,
const uint32_t seq_offset_smem,
const uint32_t kv_head_idx,
const uint32_t kv_num_heads,
const uint32_t tidx,
const uint32_t chunk_start,
const uint32_t chunk_end) {
int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
if (block_id < 0) {
block_id = 0;
}
const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
// 8/16 T/int8 each time
const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
smem + smem_offset_base + vid * CACHE_VEC_SIZE,
kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
seq_offset_gmem < chunk_end
);
}
}
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
const CacheT* k_smem,
const uint32_t kv_idx_base,
const uint32_t stage_idx,
const uint32_t iter_base,
const uint32_t iter_bound,
const uint32_t tidx,
const uint32_t gid,
const float scale,
float *s,
softmax_state_ts<vec_size, T, num_tile_v>& st) {
const CacheT* smem;
AlignedVector<T, vec_size> q_vec;
AlignedVector<T, vec_size> k_vec;
float m_prev = st.m;
// smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
#pragma unroll
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
if (iter_base + j < iter_bound) {
if constexpr (std::is_same<T, half>::value) {
s[j] = 0.f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = 0.f;
}
#pragma unroll
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
for (uint32_t i = 0; i < vec_size; ++i) {
s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
}
}
#pragma unroll
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
}
__syncthreads();
} else {
if constexpr (std::is_same<T, half>::value) {
s[j] = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = -3.38953e38f;
}
}
st.m = st.m > s[j] ? st.m : s[j];
}
// T o_scale = hexp(m_prev - st.m);
float o_scale = __expf(m_prev - st.m);
st.d *= o_scale;
#pragma unroll
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
// s[j] = hexp(s[j] - st.m);
s[j] = __expf(s[j] - st.m);
st.d += s[j];
}
#pragma unroll
for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
for (uint32_t i = 0; i < vec_size; ++i) {
st.o[tile_id][i] *= o_scale;
}
}
}
template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
__device__ __forceinline__ void compute_sv(const float *s,
const CacheT *base_v_smem,
const uint32_t stage_idx,
const uint32_t iter_base,
const uint32_t iter_bound,
const uint32_t tidx,
softmax_state_ts<vec_size, T, num_tile>& st) {
const CacheT* v_smem;
AlignedVector<T, vec_size> v_vec;
#pragma unroll
for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
uint32_t tile_id = vid / bdx;
#pragma unroll
for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
}
}
}
}
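
The softmax_state_t::merge above (together with the inline rescaling in compute_qk) implements the online-softmax accumulation used for chunked attention: keep a running max m, a rescaled denominator d, and a partial output o, and rescale both sides whenever a new partial result is merged. A minimal host-side C++ sketch of that merge rule, using scalar float state instead of half/bfloat16 vectors (the struct, values, and main() below are illustrative, not from the source):

#include <cmath>
#include <cstdio>

// Running online-softmax state: partial output o, running max m, denominator d.
struct SoftmaxState {
  float o = 0.f;
  float m = -3.38953e38f;  // effectively -inf for float
  float d = 0.f;

  // Merge another partial state; rescale both sides to the new running max so
  // the exponentials stay numerically stable (same rule as softmax_state_t::merge).
  void merge(float other_o, float other_m, float other_d) {
    const float m_new = m > other_m ? m : other_m;
    const float s1 = std::exp(m - m_new);
    const float s2 = std::exp(other_m - m_new);
    d = d * s1 + other_d * s2;
    o = o * s1 + other_o * s2;
    m = m_new;
  }

  float normalized() const { return o / d; }  // final softmax-weighted output
};

int main() {
  SoftmaxState st;
  st.merge(2.0f, 1.0f, 3.0f);  // partial result of chunk 0
  st.merge(5.0f, 2.5f, 4.0f);  // partial result of chunk 1
  std::printf("merged output: %f\n", st.normalized());
  return 0;
}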

View File

@@ -0,0 +1,560 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decode_attention_func.cuh"
#define CHECK(call) \
do \
{ \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line %d:\n", __LINE__); \
printf(" Error code:%d\n", error_code); \
printf(" Error text:%s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
}while(0)
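// Illustrative usage (not from the source): wrap CUDA runtime calls so failures
// print file/line and abort instead of being silently ignored, e.g.
//   CHECK(cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream));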
template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim]
const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads]
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cum_offsets,
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
const float in_scale,
const int num_chunks,
const int chunk_size,
const int max_seq_len,
const int num_heads,
const int head_dim) {
const int vid = threadIdx.x, ty = threadIdx.y;
const int qid = blockIdx.x, hid = blockIdx.y;
const int seq_len_q = seq_lens_q[qid];
if (seq_len_q == 0) return;
int seq_len_kv = seq_lens_kv[qid];
if (seq_len_kv == 0) return;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
return;
}
__shared__ T smem[bdy * HEAD_DIM];
__shared__ T md_smem[bdy * 2];
const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]);
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
T m;
T d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = __float2half(-5e4f);
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = __float2bfloat16(-3.38953e38f);
}
// merge per ty
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
T m_prev = m;
T d_prev = d;
const T m_now = multi_m[offset];
const T d_now = multi_d[offset];
m = m_prev > m_now ? m_prev : m_now;
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
d = d * scale1 + d_now * scale2;
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
}
}
// store ty res
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
md_smem[2 * ty] = m;
md_smem[2 * ty + 1] = d;
__syncthreads();
// merge bdy
softmax_state_t<vec_size, T> st{};
const uint32_t iter_num = min(num_chunks_this_seq, bdy);
#pragma unroll
for (int i = 0; i < iter_num; i++) {
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
AlignedVector<OutT, vec_size> out_vec;
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
out_vec[i] = static_cast<OutT>(st.o[i]);
}
Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
}
template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim]
CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim]
CacheT * __restrict__ cache_v,
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cum_offsets,
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
const float scale,
const float in_scale,
const uint32_t chunk_size,
T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim]
T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads]
T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads]
OutT * __restrict__ out) {
const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t bid = bidx, gid = threadIdx.y;
const uint32_t tidx = threadIdx.x;
constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;
const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
const uint32_t kv_num_heads = gridDim.z;
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
const int *block_table_now = block_table + bid * max_block_num_per_seq;
const uint32_t num_chunks = gridDim.y;
const uint32_t chunk_id = blockIdx.y;
const uint32_t q_len = seq_lens_q[bid];
if (q_len <= 0) {
return;
}
uint32_t kv_len = seq_lens_kv[bid]; // decoder length only; q_len is added below
if (kv_len <= 0) {
return;
}
kv_len += q_len;
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (chunk_id >= num_chunk_this_seq) {
return;
}
const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
const uint32_t chunk_len = chunk_end - chunk_start;
extern __shared__ uint8_t smem[];
const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
T *q_smem = reinterpret_cast<T*>(smem); // [HEAD_DIM_QK * sizeof(T)]
T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
#pragma unroll
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
}
__syncthreads();
using VecT = AlignedVector<T, VEC_SIZE>;
VecT q_vec;
#pragma unroll
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
q_vec[i] *= scale;
}
Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
}
CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
uint32_t stage_idx = 0;
constexpr int loop_times = DEAL_EACH_TIME / bdy;
#pragma unroll
for (int i = 0; i < NUM_STAGES; ++i) {
#pragma unroll
for (int j = 0; j < loop_times; ++j) {
const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
const uint32_t k_seq_id = chunk_start + k_seq_offset;
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
kv_smem,
cache_k,
block_table_now,
k_seq_id,
k_seq_offset,
kv_head_idx,
kv_num_heads,
tidx,
chunk_start,
chunk_end
);
}
commit_group();
stage_idx = (stage_idx + 1) % NUM_STAGES;
}
softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
float s[DEAL_EACH_TIME];
const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
for (int iter = 0; iter < num_iters; ++iter) {
wait_group<NUM_STAGES - 1>();
__syncthreads();
// compute qk
compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
cu_q_smem,
kv_smem,
chunk_start + iter * DEAL_EACH_TIME,
stage_idx,
iter * DEAL_EACH_TIME,
chunk_len,
tidx,
gid,
scale,
s,
st
);
__syncthreads();
// compute sv
compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
s,
kv_smem,
stage_idx,
iter * DEAL_EACH_TIME,
chunk_len,
tidx,
st
);
__syncthreads();
#pragma unroll
for (int j = 0; j < loop_times; ++j) {
const uint32_t k_seq_offset = j * bdy + gid;
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
kv_smem,
cache_k,
block_table_now,
chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
stage_idx * DEAL_EACH_TIME + k_seq_offset,
kv_head_idx,
kv_num_heads,
tidx,
chunk_start,
chunk_end
);
}
commit_group();
stage_idx = (stage_idx + 1) % NUM_STAGES;
}
wait_group<0>();
__syncthreads();
// normalize if not partition_kv
for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
const uint32_t tile_id = vid / bdx;
if (!partition_kv || num_chunk_this_seq == 1) {
st.normalize(tile_id);
}
if (partition_kv && num_chunk_this_seq > 1) {
const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
tmp_m[head_idx] = st.m;
tmp_d[head_idx] = st.d;
} else {
Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
}
}
}
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
void MultiQueryDecoderAttention(
const AppendAttnMetaData& meta_data,
cudaStream_t &stream,
const paddle::Tensor &q,
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
const int max_seq_len,
const int max_dec_len,
const float rope_scale,
const float rope_theta,
const float softmax_scale,
const float in_scale,
paddle::Tensor *out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto token_num = meta_data.token_nums;
auto bsz = meta_data.batch_size;
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
constexpr int num_stages = NUM_STAGE;
constexpr int vec_size = 16 / sizeof(T); // 8 16 32
constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32
constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;
constexpr int blocky = GROUP_SIZE;
const int gridx = bsz;
constexpr int num_threads = blockx * blocky;
auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
uint32_t cache_smem_bytes = 0;
const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);
const uint32_t chunk_size = get_max_partition_size(bsz);
const int num_chunks = div_up(max_dec_len, chunk_size);
size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
assert(act_blocks_per_sm > 1);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
const int num_blocks_need = gridx * num_chunks * kv_num_heads;
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);
dim3 grids(gridx, num_chunks, kv_num_heads);
dim3 blocks(blockx, blocky);
if (num_chunks <= 1) {
auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
cache_vec_size, blockx, blocky>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
softmax_scale,
in_scale,
chunk_size,
nullptr,
nullptr,
nullptr,
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
);
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
} else {
auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
tmp_workspace = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads));
splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
softmax_scale,
in_scale,
chunk_size,
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
);
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
constexpr int mblockx = HEAD_DIM_V / vec_size;
constexpr int bdy = 256 / mblockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(mblockx, bdy);
merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
in_scale,
num_chunks,
chunk_size,
max_seq_len,
num_heads,
HEAD_DIM_V
);
}
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
}
template <typename T>
void DecodeMLAAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim_qk = meta_data.head_dims;
const auto head_dim_v = meta_data.head_dims_v;
const float rope_scale = 0.0;
const float rope_theta = 0.0;
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
const uint32_t num_stage = get_cascade_attention_num_stages();
const uint32_t num_threads = get_cascade_attention_num_threads();
DISPATCH_CAUSAL(causal, CAUSAL,
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);
template void DecodeMLAAttentionKernel<paddle::float16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);

View File

@@ -0,0 +1,291 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mla_cache_kernel.cuh"
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto num_tokens = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;
constexpr int PackSize = 16 / sizeof(DataType_);
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
prefill_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
return {};
}
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cum_offsets.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
padding_offsets,
cum_offsets,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
padding_offsets,
cum_offsets,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto bsz = meta_data.batch_size;
auto token_num = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
constexpr int PackSize = 16 / sizeof(DataType_);
const int blocksize = 128;
int grid_size = 1;
if (speculate_decoder) {
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
} else {
const uint32_t elem_nums = bsz * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
}
return {};
}
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cum_offsets.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
padding_offsets,
cum_offsets,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
padding_offsets,
cum_offsets,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
PD_BUILD_OP(prefill_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_decoder",
"padding_offsets",
"cum_offsets",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
PD_BUILD_OP(decode_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_encoder",
"padding_offsets",
"cum_offsets",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int",
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
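The two ops registered above expose the cache writers to the framework; the decode path appends one token per request into a paged latent cache laid out as [num_blocks, kv_num_heads, block_size, nope_size + pe_size]. Below is a minimal host-side C++ sketch of the same index arithmetic; the sizes nope_size = 512 and pe_size = 64 are the usual DeepSeek-V3 MLA values and, like the toy block table, are assumptions for illustration only.

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed MLA sizes: 512-dim compressed KV ("nope") + 64-dim rotary part ("pe").
  const int kv_num_heads = 1, block_size = 64;
  const int nope_size = 512, pe_size = 64, all_size = nope_size + pe_size;

  // Toy block table: logical block index -> physical block id.
  const int block_tables[] = {7, 3, 9};
  const int write_seq_id = 130;                                    // position being decoded
  const int block_idx = block_tables[write_seq_id / block_size];   // -> 9
  const int block_offset = write_seq_id % block_size;              // -> 2

  // Destination of element h_bias of head hi, mirroring the kernel's tgt_idx math.
  const int hi = 0, h_bias = 5;
  const std::int64_t row_base =
      std::int64_t(block_idx) * kv_num_heads * block_size * all_size +
      hi * block_size * all_size + block_offset * all_size;
  std::printf("nope element -> %lld, pe element -> %lld\n",
              (long long)(row_base + h_bias),              // first nope_size slots of the row
              (long long)(row_base + nope_size + h_bias)); // pe slots follow them
  return 0;
}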


@@ -0,0 +1,242 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1>
__global__ void decode_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [bsz, kv_num_heads, nope_size (512)]
    const T* __restrict__ kv_pe,    // [bsz, kv_num_heads, pe_size (64)]
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
    if (bias < nope_hidden_size) {  // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
start_token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
start_token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
template <typename T, int VecSize = 1>
__global__ void speculate_decode_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [token_num, kv_num_heads, nope_size (512)]
    const T* __restrict__ kv_pe,    // [token_num, kv_num_heads, pe_size (64)]
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / hidden_size;
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int write_seq_id =
seq_lens[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens[ori_bi],
token_id,
cum_offsets[ori_bi]);
}
    if (bias < nope_hidden_size) {  // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
token_id * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
token_id * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
template <typename T, int VecSize = 1>
__global__ void prefill_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [token_num, kv_num_heads, nope_size (512)]
    const T* __restrict__ kv_pe,    // [token_num, kv_num_heads, pe_size (64)]
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    //  nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
const uint32_t ori_bi = ori_token_idx / max_seq_len;
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
const uint32_t block_offset = ori_seq_id % block_size;
    if (bias < nope_hidden_size) {  // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
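All three kernels in this file are grid-stride loops over elem_cnt = tokens × kv_num_heads × (nope_size + pe_size), vectorized by PackSize = 16 / sizeof(T) elements per thread. The snippet below is a hedged sketch of that launch arithmetic; GetNumBlocksSketch is a stand-in for the GetNumBlocks helper used by the real wrappers, whose exact capping policy is not shown in this diff.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Stand-in for GetNumBlocks<128>: one block per 128 packs, capped arbitrarily.
static void GetNumBlocksSketch(int pack_num, int* grid_size) {
  *grid_size = std::min((pack_num + 127) / 128, 4096);
}

int main() {
  const int token_num = 256, kv_num_heads = 1;
  const int nope_size = 512, pe_size = 64;
  const std::uint32_t elem_nums =
      static_cast<std::uint32_t>(token_num) * kv_num_heads * (nope_size + pe_size);
  const int pack_size = 16 / 2;      // 16-byte vectors of bfloat16 -> 8 elements
  const int pack_num = elem_nums / pack_size;
  int grid_size = 1;
  GetNumBlocksSketch(pack_num, &grid_size);
  std::printf("elems=%u packs=%d grid=%d blocks x 128 threads\n",
              elem_nums, pack_num, grid_size);
  return 0;
}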


@@ -0,0 +1,38 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "utils.cuh"
template <typename T>
void DecodeMLAAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);


@@ -25,6 +25,7 @@ struct AppendAttnMetaData {
int kv_num_heads;
int token_nums;
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
};
@@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} \
}
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
if (num_stage == 2) { \
constexpr size_t NUM_STAGE = 2; \
__VA_ARGS__ \
#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}
#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
case 512: { \
constexpr size_t HEAD_DIM = 512; \
__VA_ARGS__ \
break; \
} \
case 576: { \
constexpr size_t HEAD_DIM = 576; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
if (num_stage == 2) { \
constexpr size_t NUM_STAGE = 2; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_stage: ", num_stage); \
}
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
@@ -328,10 +375,13 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
constexpr size_t cache_bytes = 4; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the cache_type: ", cache_type); \
}
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
if (deal_each_time == 32) { \
if (deal_each_time == 32) { \
constexpr size_t DEAL_EACH_TIME = 32; \
__VA_ARGS__ \
} else if (deal_each_time == 64) { \
@@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 128) { \
constexpr size_t GROUP_SIZE = 128; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
}
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
if (block_shape_q <= 16) { \
constexpr size_t BLOCK_SHAPE_Q = 16; \


@@ -316,6 +316,96 @@ void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
int64_t num_experts);
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch);
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder);
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len);
void FusedRotaryPositionEncoding(
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
// [num_tokens, num_heads * head_size]
paddle::Tensor& key,
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
// head_size]
const paddle::Tensor& position_ids, // [num_tokens]
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
int head_size,
bool is_neox);
std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& query_bias,
const paddle::optional<paddle::Tensor>& query_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder);
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
@@ -370,6 +460,14 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
paddle::Tensor const &input,
paddle::Tensor &scales, float scale_ub);
std::vector<paddle::Tensor> NoauxTc(
paddle::Tensor& scores,
paddle::Tensor& scores_with_bias,
int n_group,
int topk_group,
int topk,
float routed_scaling_factor);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -627,6 +725,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"),
py::arg("is_zp_float"));
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function");
/**
@@ -653,4 +753,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
"dynamic_per_token_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function");
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
}

custom_ops/gpu_ops/env.h

@@ -0,0 +1,64 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <cstdlib>
#include <string>
inline uint32_t get_decoder_block_shape_q() {
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
static const uint32_t decoder_block_shape_q =
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
return decoder_block_shape_q;
}
inline uint32_t get_encoder_block_shape_q() {
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
static const uint32_t encoder_block_shape_q =
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
return encoder_block_shape_q;
}
inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(max_partition_size_env));
return max_partition_size;
}
inline uint32_t get_cascade_attention_deal_each_time() {
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
static const uint32_t cascade_attention_deal_each_time =
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
}
inline uint32_t get_cascade_attention_num_stages() {
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
static const uint32_t cascade_attention_num_stages =
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
}
inline uint32_t get_cascade_attention_num_threads() {
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
static const uint32_t cascade_attention_num_threads =
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
}
inline bool get_mla_use_tensorcore() {
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
static const uint32_t mla_use_tensorcore =
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
return mla_use_tensorcore != 0 ? true : false;
}
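Each getter above reads its FLAGS_* environment variable once (the result is cached in a function-local static) and falls back to the default shown, so any override must be in the environment before the first call. A hedged usage sketch; the include path for env.h and the POSIX setenv call are assumptions here.

#include <cstdio>
#include <cstdlib>

#include "env.h"  // assumed include path to the header above

int main() {
  // Must happen before the first call to the getter; results are cached after that.
  setenv("FLAGS_mla_use_tensorcore", "0", /*overwrite=*/1);

  std::printf("decoder block_shape_q = %u\n", get_decoder_block_shape_q());        // default 16
  std::printf("max partition size    = %u\n", get_max_partition_size(/*bsz=*/8));  // default 32768
  std::printf("mla tensor core       = %d\n", static_cast<int>(get_mla_use_tensorcore()));  // 0 here
  return 0;
}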


@@ -0,0 +1,146 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_kernel(
T* __restrict__ arr,
const T* __restrict__ cos_ptr,
const T* __restrict__ sin_ptr,
int rot_offset,
int embed_dim) {
int x_index, y_index;
T cos, sin;
if (IS_NEOX) {
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = cos_ptr[x_index];
sin = sin_ptr[x_index];
} else {
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = cos_ptr[x_index / 2];
sin = sin_ptr[x_index / 2];
}
const T x = arr[x_index];
const T y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template <typename T, bool IS_NEOX>
__global__ void apply_rotary_embedding_kernel(
T* __restrict__ query, // [num_tokens, num_heads, head_size]
T* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const int* __restrict__ position_ids, // [num_tokens]
const T* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int pos = position_ids[token_idx];
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
const T* cos_ptr = cache_ptr;
const T* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
}
void FusedRotaryPositionEncoding(
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
// [num_tokens, num_heads * head_size]
paddle::Tensor& key,
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
// head_size]
const paddle::Tensor& position_ids, // [num_tokens]
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
int head_size,
bool is_neox) {
int64_t num_tokens = query.dims()[0];
int num_heads = query.numel() / num_tokens / head_size;
int num_kv_heads = key.numel() / num_tokens / head_size;
int rot_dim = cos_sin_cache.dims()[1];
int64_t query_stride = num_heads * head_size;
int64_t key_stride = num_kv_heads * head_size;
if (num_tokens > 65535) {
PD_THROW(
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
}
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
query.dtype(), "apply_rotary_embedding_kernel", [&] {
if (is_neox) {
apply_rotary_embedding_kernel<data_t, true>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
key.data<data_t>(),
position_ids.data<int>(),
cos_sin_cache.data<data_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
apply_rotary_embedding_kernel<data_t, false>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
key.data<data_t>(),
position_ids.data<int>(),
cos_sin_cache.data<data_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}
PD_BUILD_OP(fused_rotary_position_encoding)
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
.Outputs({"query_out", "key_out"})
.Attrs({"head_size: int", "is_neox: bool"})
.SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
.SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
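The kernel rotates each (x, y) pair of a head by the cached angle for its position; the only difference between the two layouts is which elements form a pair. The CPU reference below clarifies that pairing. It uses a single angle for brevity, whereas the kernel looks up per-offset cos/sin values from cos_sin_cache; all values are illustrative.

#include <cmath>
#include <cstdio>
#include <vector>

// Rotate one (x, y) pair by an angle: the same formula the device helper applies.
void rotate_pair(float& x, float& y, float c, float s) {
  const float x0 = x, y0 = y;
  x = x0 * c - y0 * s;
  y = y0 * c + x0 * s;
}

int main() {
  const int rot_dim = 8, embed_dim = rot_dim / 2;
  std::vector<float> q(rot_dim, 1.0f);
  // Single angle for brevity; the kernel reads cos/sin per rotary offset.
  const float theta = 0.5f, c = std::cos(theta), s = std::sin(theta);
  const bool is_neox = true;                  // flip to false for the interleaved layout
  for (int i = 0; i < embed_dim; ++i) {
    if (is_neox) rotate_pair(q[i], q[embed_dim + i], c, s);   // pair (i, i + rot_dim/2)
    else         rotate_pair(q[2 * i], q[2 * i + 1], c, s);   // pair (2i, 2i + 1)
  }
  for (float v : q) std::printf("%.3f ", v);
  std::printf("\n");
  return 0;
}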


@@ -0,0 +1,86 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
    const int* seq_lens_encoder,  // [bsz] encoder length of each batch
    const int* seq_lens_decoder,  // [bsz] decoder length of each batch
    const int* seq_lens_this_time,
    int* position_ids,  // flattened 1-D position_ids output
    int* mask_encoder_batch,
    const int bsz) {  // batch size
  // current thread index (each thread handles one batch)
  int tid = threadIdx.x;
  if (tid >= bsz) return;
  // compute this batch's offset into the flattened output
  int offset = 0;
  for (int i = 0; i < tid; i++) {
    offset += seq_lens_encoder[i];
    if (seq_lens_decoder[i] > 0) {
      offset += seq_lens_this_time[i];
    }
  }
  // encoder and decoder lengths of the current batch
  int encoder_len = seq_lens_encoder[tid];
  int decoder_len = seq_lens_decoder[tid];
  int seq_len_this_time = seq_lens_this_time[tid];
  // write position_ids for the encoder (prefill) tokens
  for (int i = 0; i < encoder_len; i++) {
    position_ids[offset + i] = i;
    mask_encoder_batch[offset + i] = 1;
  }
  offset += encoder_len;
  // write position_ids for the decoder tokens
  if (decoder_len > 0) {
    for (int i = 0; i < seq_len_this_time; i++) {
      position_ids[offset + i] = decoder_len + i;  // continue from the decoded length
mask_encoder_batch[offset + i] = 0;
}
}
}
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch) {
const int bsz = seq_lens_this_time.shape()[0];
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
const_cast<int*>(position_ids.data<int>()),
const_cast<int*>(mask_encoder_batch.data<int>()),
bsz);
}
PD_BUILD_OP(get_position_ids_and_mask_encoder_batch)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"position_ids",
"mask_encoder_batch"})
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"},
{"mask_encoder_batch", "mask_encoder_batch_out"}})
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
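A worked example of what the op produces for a mixed batch, mirrored serially on the host with illustrative values: batch 0 is prefilling 3 tokens and batch 1 is decoding 2 draft tokens on top of 5 already-cached tokens, so the expected outputs are position_ids = 0 1 2 5 6 and mask_encoder_batch = 1 1 1 0 0.

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> seq_lens_encoder   = {3, 0};
  const std::vector<int> seq_lens_decoder   = {0, 5};
  const std::vector<int> seq_lens_this_time = {3, 2};
  std::vector<int> position_ids, mask_encoder_batch;

  for (size_t b = 0; b < seq_lens_encoder.size(); ++b) {
    // Prefill tokens count from 0 and are flagged as encoder tokens.
    for (int i = 0; i < seq_lens_encoder[b]; ++i) {
      position_ids.push_back(i);
      mask_encoder_batch.push_back(1);
    }
    // Decode tokens continue from the number of tokens already in the cache.
    if (seq_lens_decoder[b] > 0) {
      for (int i = 0; i < seq_lens_this_time[b]; ++i) {
        position_ids.push_back(seq_lens_decoder[b] + i);
        mask_encoder_batch.push_back(0);
      }
    }
  }
  for (int p : position_ids) std::printf("%d ", p);
  std::printf("\n");
  for (int m : mask_encoder_batch) std::printf("%d ", m);
  std::printf("\n");
  return 0;
}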


@@ -39,10 +39,12 @@ namespace cub = hipcub;
#include <fstream>
#include <iostream>
#include "env.h"
#include "paddle/extension.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/cuda_stream.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -513,3 +515,10 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
return max_shared_mem_per_block_opt_in;
}
inline int GetSMVersion() {
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;
}


@@ -0,0 +1,255 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#include <cute/tensor.hpp>
#include <cutlass/detail/helper_macros.hpp>
#include "utils.cuh"
namespace mla_attn {
using namespace cute;
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); }
};
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
};
template <int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator& op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template <>
struct Allreduce<2> {
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator& op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& summary, Operator& op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0>& dst,
Tensor<Engine1, Layout1>& src, Operator& op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++) {
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& summary, Operator& op) {
thread_reduce_<init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& max) {
MaxOp<float> max_op;
reduce_<init>(tensor, max, max_op);
}
template <bool init, bool warp_reduce = true, typename Engine0, typename Layout0, typename Engine1,
typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& sum) {
SumOp<float> sum_op;
thread_reduce_<init>(tensor, sum, sum_op);
if constexpr (warp_reduce) {
quad_allreduce_(sum, sum, sum_op);
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_exp2(Tensor<Engine0, Layout0>& tensor,
Tensor<Engine1, Layout1> const& max) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
auto row_max = max(mi);
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = __expf(tensor(mi, ni) - row_max);
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor,
Tensor<Engine1, Layout1> const& max,
const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
auto row_max = max(mi);
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// row_max * scale is a constant for each row, so we can use fma here
tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale);
}
}
}
template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE>
struct OnlineSoftmax {
constexpr static float fill_value = -5e4;
using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
TensorT row_max, row_sum, scores_scale;
float sm_scale_log2;
CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) {
clear(scores_scale);
};
__forceinline__ __device__ TensorT get_lse() const { return row_sum; }
template <bool init, typename Tensor0>
__forceinline__ __device__ TensorT update(Tensor0& acc_s) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
if constexpr (init) {
reduce_max</*init=*/true>(scores, row_max);
if constexpr (WITH_SCALE) {
scale_apply_exp2(scores, row_max, sm_scale_log2);
} else {
apply_exp2(scores, row_max);
}
reduce_sum</*init=*/true, /*warp_reduce=*/false>(scores, row_sum);
} else {
// update row_max
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*init=*/false>(scores, row_max);
// update scores_scale and scale row_sum
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
if constexpr (WITH_SCALE) {
scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
} else {
scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur);
}
row_sum(mi) *= scores_scale(mi);
}
// perform exp2 on scores
if constexpr (WITH_SCALE) {
scale_apply_exp2(scores, row_max, sm_scale_log2);
} else {
apply_exp2(scores, row_max);
}
// update row_sum
reduce_sum</*init=*/false, /*warp_reduce=*/false>(scores, row_sum);
    }
    // Return the per-row rescale factors on both paths (all zeros right after init).
    return scores_scale;
};
template <typename Tensor0>
__forceinline__ __device__ TensorT finalize(Tensor0& acc_s) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.f / sum;
scores_scale(mi) = inv_sum;
row_max(mi) *= sm_scale_log2;
}
return scores_scale;
};
template <typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
template <typename Tensor1, typename Tensor2>
__forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) {
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale_input(mi);
}
}
};
};
} // namespace mla_attn
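For reference, OnlineSoftmax implements the usual streaming-softmax recurrence, with row_max playing the role of m, row_sum the role of d, and scores_scale the factor consumed by rescale_o. Written with a plain exponential for clarity (the kernel folds the softmax scale into the exponent via sm_scale_log2):

m^{(t)} = \max\bigl(m^{(t-1)},\ \max_j s^{(t)}_j\bigr), \qquad
c^{(t)} = e^{\,m^{(t-1)} - m^{(t)}}, \qquad
d^{(t)} = c^{(t)}\, d^{(t-1)} + \sum_j e^{\,s^{(t)}_j - m^{(t)}}

The caller keeps the partial output consistent with o^{(t)} = c^{(t)}\, o^{(t-1)} + \widetilde{P}^{(t)} V^{(t)}, and the final division by d is what the finalize/rescale_o pair above performs.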


@@ -0,0 +1,235 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <type_traits>
#include <vector>
#include "cute/tensor.hpp"
#include "mla_hopper.cuh"
#include <iostream>
#include <string>
#include <sstream>
#include "batch_mla_with_paged_kv_cache.h"
#include "env.h"
using namespace cute;
using namespace mla_attn;
using namespace std;
template <typename T>
struct cascade_type_traits {
using type = T;
using cutlass_type = T;
};
template <>
struct cascade_type_traits<phi::dtype::bfloat16> {
using type = __nv_bfloat16;
  using cutlass_type = cutlass::bfloat16_t;
};
template <>
struct cascade_type_traits<phi::dtype::float16> {
using type = half;
using cutlass_type = cutlass::half_t;
};
template <>
struct cascade_type_traits<phi::dtype::float8_e4m3fn> {
using type = __nv_fp8_e4m3;
using cutlass_type = cutlass::float_e4m3_t;
};
template <typename T>
void BatchMLAWithPagedKVCacheKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out) {
using NV_TYPE = typename cascade_type_traits<T>::type;
using CUTLASS_TYPE = typename cascade_type_traits<T>::cutlass_type;
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
const auto q_head_num = meta_data.q_num_heads;
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
const auto max_block_num = bsz * max_block_num_per_seq;
const uint32_t chunk_size = get_max_partition_size(bsz);
int q_head_dim = meta_data.head_dims;
int k_head_dim = meta_data.head_dims;
int v_head_dim = meta_data.head_dims_v;
// int num_chunks = max_dec_len / chunk_size;
int num_chunks = div_up(max_dec_len, chunk_size);
auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
O_tmp = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num * v_head_dim));
m_tmp = allocator->Allocate(
sizeof(float) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
d_tmp = allocator->Allocate(
sizeof(float) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
Params<CUTLASS_TYPE, CUTLASS_TYPE, CUTLASS_TYPE, int> params = {};
params.Q = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(q.data<T>()));
params.KV = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(latent_cache.data<T>()));
params.O = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(out->data<T>()));
params.O_tmp = reinterpret_cast<CUTLASS_TYPE*>(O_tmp->ptr());
params.m = reinterpret_cast<float*>(m_tmp->ptr());
params.d = reinterpret_cast<float*>(d_tmp->ptr());
params.block_tables = const_cast<int*>(block_tables.data<int>());
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
params.num_blocks_x_int = num_blocks_x;
params.q_stride_bsz = q_head_num * q_head_dim;
params.q_stride_head_num = q_head_dim;
params.kv_stride_block_num = block_size * k_head_dim;
params.kv_stride_block_size = k_head_dim;
params.o_stride_bsz = q_head_num * v_head_dim;
params.o_stride_head_num = v_head_dim;
params.bsz = bsz;
params.token_num = token_num;
params.max_seq_len = max_seq_len;
params.max_block_num = max_block_num;
params.max_block_num_per_seq = max_block_num_per_seq;
params.q_num_head = q_head_num;
params.qk_head_dim = q_head_dim;
params.vo_head_dim = v_head_dim;
params.block_size = block_size;
params.max_draft_token_num = draft_token_num;
params.sm_scale = softmax_scale;
params.chunk_size = chunk_size;
params.chunk_num = num_chunks;
if (q_head_dim == 576) {
BatchMLAWithPagedKVCacheDispatched<576, 512, NV_TYPE>(
params, stream
);
} else {
PD_THROW("error!!! q_head_dim must be 576 !!!\n");
}
}
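The wrapper splits the KV sequence into chunks of get_max_partition_size(bsz) positions and allocates per-chunk scratch for the partial outputs (O_tmp) and the softmax statistics (m, d). A back-of-envelope sketch of that sizing with illustrative values; the real chunk size comes from FLAGS_cascade_attention_max_partition_size, whose default is 32768.

#include <cstdio>

int main() {
  // Illustrative sizes only.
  const int bsz = 32, draft_token_num = 2, q_head_num = 128, v_head_dim = 512;
  const int max_dec_len = 65536, chunk_size = 32768;
  const int num_chunks = (max_dec_len + chunk_size - 1) / chunk_size;  // div_up -> 2

  const long long o_tmp_elems =
      1LL * num_chunks * bsz * draft_token_num * q_head_num * v_head_dim;
  const long long md_elems =
      1LL * num_chunks * bsz * draft_token_num * q_head_num;
  std::printf("num_chunks=%d  O_tmp elems=%lld  m/d elems=%lld each\n",
              num_chunks, o_tmp_elems, md_elems);
  return 0;
}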
template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);
template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);


@@ -0,0 +1,69 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "paddle/extension.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/allocator.h"
#include "append_attn/utils.cuh"
template <typename T>
void BatchMLAWithPagedKVCacheKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);


@@ -0,0 +1,175 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_EPILOGUE_CUH_
#define ATTENTION_HOPPER_EPILOGUE_CUH_
#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename Ktraits>
struct CollectiveEpilogue {
using DTypeO = typename Ktraits::DTypeO;
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
static constexpr int NUM_WARPS = Ktraits::NUM_WARPS;
static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;
static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeTmpT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
using StrideTmpT = cute::Shape<int32_t, _1, int32_t, int32_t>;
using LayoutTmpT = cute::Layout<ShapeTmpT, StrideTmpT>;
using ShapeNTMAT = cute::Shape<int32_t, int32_t>;
using StrideNTMAT = cute::Shape<int32_t, _1>;
using LayoutNTMAT = cute::Layout<ShapeNTMAT, StrideNTMAT>;
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
using TMA_O = decltype(make_tma_copy(
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<DTypeO*>(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{},
select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O
static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v<DTypeO>); // 8
static_assert(HEAD_DIM_VO % VEC_SIZE == 0);
static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64
static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0);
static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, DTypeO>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<NUM_ROWS>{}, Int<NUM_THREADS_PER_ROW>{}), LayoutRight{}));
using TiledCopyOValLayout =
decltype(cute::make_layout(cute::make_shape(_1{}, Int<VEC_SIZE>{}), LayoutRight{}));
using TiledCopyO =
decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout
TiledCopyOValLayout{} // Val layout
));
struct Arguments {
DTypeO* O_ptr;
LayoutNTMAT const layout_O;
DTypeO* O_ptr_tmp;
LayoutNTMAT const layout_O_tmp;
};
// Device side kernel params
struct Params {
DTypeO* O_ptr;
LayoutNTMAT const layout_O;
DTypeO* O_ptr_tmp;
LayoutNTMAT const layout_O_tmp;
};
static Params to_underlying_arguments_ntma(Arguments const& args) {
return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& epilogue_params) {}
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE,
typename TiledMma>
CUTLASS_DEVICE void store(Params const& epilogue_params,
FrgTensorO const& tOrO,
FrgTensorLSE const& lse,
SharedStorage& shared_storage,
TiledMma tiled_mma,
const int thread_idx,
const int bid,
const int bsz,
const int seq_len_now,
const int start_token_idx,
const int tile_idx,
const int kv_len,
const int chunk_size,
const int max_draft_token_num,
const int o_stride_bsz) {
const int num_chunks = cute::ceil_div(kv_len, chunk_size);
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = convert_type<DTypeO>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// make sure gemm done
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
// r2s
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
// make sure r2s done
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
TiledCopyO gmem_tiled_copy_O;
auto O_ptr = num_chunks == 1 ? epilogue_params.O_ptr + start_token_idx * o_stride_bsz : epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O);
Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{});
Tensor cO = make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx)
ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx);
Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D)
Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D)
Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D)
Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D))
Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D))
Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D))
// copy if not out of bound
auto predicate_fn = [&](auto coords) {
auto s_coords = tOcOGroup(_0{}, coords);
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now);
};
copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
}
};
} // namespace mla_attn
#endif // ATTENTION_HOPPER_EPILOGUE_CUH_
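Note on the epilogue above: store() writes directly to O when the sequence fits in a single chunk and into the per-chunk scratch buffer O_tmp otherwise, leaving the cross-chunk merge to a later kernel. A minimal sketch of that pointer selection (the helper name is hypothetical; the arguments mirror store()):

// Sketch only, not part of the diff: destination selection used by store().
template <typename DTypeO>
DTypeO* select_output_ptr(DTypeO* O, DTypeO* O_tmp, int num_chunks,
                          int start_token_idx, int tile_idx, int bid, int bsz,
                          int max_draft_token_num, int o_stride_bsz) {
  // Single-chunk sequences write the final result; multi-chunk sequences park
  // a partial result in O_tmp at a (tile, batch) offset for the merge kernel.
  return num_chunks == 1
             ? O + start_token_idx * o_stride_bsz
             : O_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
}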


@@ -0,0 +1,163 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#include <type_traits>
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
namespace mla_attn {
using namespace cute;
template <typename MainloopPipeline, typename MainloopPipelineQ, class DTypeQ, class DTypeKV, class DTypeQKAccum, class DTypeOut, class IdType,
int BLOCK_SHAPE_KV, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutP, class SmemLayoutRow, class SmemLayoutO>
struct alignas(16) SharedStorageQKVO {
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutQ>> smem_q;
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutP>> smem_p;
alignas(16) cute::array_aligned<DTypeQKAccum, cute::cosize_v<SmemLayoutRow>> smem_scale;
union {
alignas(16) cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutK>> smem_kv;
alignas(16) cute::array_aligned<DTypeOut, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q;
alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv;
};
};
template <bool USE_TMA_LOAD_KV_, int HEAD_DIM_QK_, int HEAD_DIM_VO_, int GROUP_SIZE_, int BLOCK_SHAPE_Q_, int BLOCK_SHAPE_KV_,
int NUM_STAGES_, typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_, typename NV_TYPE_>
struct AttentionKernelTraits {
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
using DTypeQKAccum = float;
using DTypePVAccum = float;
using NV_TYPE = NV_TYPE_;
static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_;
static constexpr int GROUP_SIZE = GROUP_SIZE_;
static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_;
static_assert(BLOCK_SHAPE_Q % 64 == 0);
static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_;
static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK;
static_assert(HEAD_DIM_QK % 32 == 0);
static_assert(HEAD_DIM_VO % 32 == 0);
static constexpr int NUM_WARPS = 12;
static constexpr int NUM_THREADS = 384;
static constexpr int NUM_PRODUCER_THREADS = 128;
using TileShape_QKD = Shape<Int<BLOCK_SHAPE_Q>, Int<BLOCK_SHAPE_KV>, Int<HEAD_DIM_QK>>;
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
static constexpr int NUM_STAGES = NUM_STAGES_;
using AtomLayoutQKD = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _1, _1>>;
using AtomLayoutPV = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _2, _1>>;
using TiledMmaQK = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<DTypeQ, DTypeKV, DTypeQKAccum, TileShape_QKD>(), AtomLayoutQKD{}));
using TiledMmaPV = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutPV{}));
using TiledMmaPVSS = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutPV{}));
static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{});
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
decltype(cute::get<2>(TileShape_QKD{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
decltype(cute::get<2>(TileShape_QKD{}))>());
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int<NUM_STAGES>{})));
using SmemLayoutVt = decltype(composition(
SmemLayoutK{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}),
get<1>(TileShape_QKD{}), Int<NUM_STAGES>{}),
Step<_2, _1, _3>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{})));
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVtOneStage = decltype(composition(
SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}),
get<2>(TileShape_PDV{}), Int<1>{}),
Step<_2, _1, _3>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
using SmemCopyAtom = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeQ>;
static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32);
using SmemLayoutRowOneStage = Layout<Shape<_2, Int<128>>, Stride<_1, _2>>;
using SmemLayoutRowTwoStage = Layout<Shape<_2, Int<128>, _2>, Stride<_1, _2, _256>>;
using SmemLayoutRow = std::conditional_t<IS_CTA_32, SmemLayoutRowTwoStage, SmemLayoutRowOneStage>;
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
decltype(cute::get<1>(TileShape_QKD{}))>());
using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{})));
using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int<BLOCK_SHAPE_Q>{}, Int<BLOCK_SHAPE_KV>{}, Int<2>{})));
using SmemLayoutP = std::conditional_t<IS_CTA_32, SmemLayoutPSSTwoStage, SmemLayoutPSSOneStage>;
using MainloopPipelineQ = typename cutlass::PipelineAsync<1>;
using PipelineStateQ = typename cutlass::PipelineState<1>;
using MainloopPipeline =
std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
typename cutlass::PipelineAsync<NUM_STAGES>>;
using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;
using SharedStorage = SharedStorageQKVO<MainloopPipeline, MainloopPipelineQ, DTypeQ, DTypeKV, DTypeQKAccum, DTypeO, IdType, BLOCK_SHAPE_KV,
SmemLayoutQ, SmemLayoutK, SmemLayoutP, SmemLayoutRow, SmemLayoutO>;
};
} // namespace mla_attn
#endif  // ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
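For reference, the dispatcher later in this diff instantiates these traits with the MLA head dims 576/512, 64x64 tiles and two KV stages. A minimal sketch of such an instantiation (assuming bf16 Q/KV/O, int block-table ids and __nv_bfloat16 as NV_TYPE; the exact types come from the caller's Params):

// Sketch only, not part of the diff.
using Traits = AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
                                     /*HEAD_DIM_QK=*/576, /*HEAD_DIM_VO=*/512,
                                     /*GROUP_SIZE=*/8,
                                     /*BLOCK_SHAPE_Q=*/64, /*BLOCK_SHAPE_KV=*/64,
                                     /*NUM_STAGES=*/2,
                                     cutlass::bfloat16_t, cutlass::bfloat16_t,
                                     cutlass::bfloat16_t, int, __nv_bfloat16>;
static_assert(Traits::NUM_THREADS == 384);         // 1 producer + 2 consumer warp groups
static_assert(Traits::NUM_PER_STAGE == 64 * 576);  // KV elements per pipeline stage
// sizeof(Traits::SharedStorage) is what the launcher later passes as dynamic smem.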


@@ -0,0 +1,348 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
#define ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename Ktraits, bool CAUSAL>
struct CollectiveMainloop {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = float;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
using TileShape_PDV = typename Ktraits::TileShape_PDV;
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(DTypeQ); // 8
static_assert(HEAD_DIM_QK % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // e.g. HEAD_DIM_QK = 576, HEAD_DIM_VO = 512
static constexpr int kGmemThreadsPerRow = 64 / kGmemElemsPerLoad; // 8
using AlignmentTypeQ = cute::uint_byte_t<static_cast<int>(sizeof(DTypeQ)) * kGmemElemsPerLoad>;
using GmemCopyAtomQ = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<AlignmentTypeQ>, DTypeQ>;
static constexpr int kNThreadsLoad = Ktraits::NUM_PRODUCER_THREADS;
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // (16, 8) for bf16 with 128 producer threads
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(make_tiled_copy(
GmemCopyAtomQ{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemLayoutAtomQ = Layout<
Shape<Int<Ktraits::NUM_PRODUCER_THREADS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // (16, 8) for bf16 with 128 producer threads
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyQ = decltype(make_tiled_copy(
GmemCopyAtomQ{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutAtomQ = typename Ktraits::SmemLayoutAtomQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using ShapeQT = cute::Shape<int32_t, int32_t>;
using StrideQT = cute::Shape<int32_t, _1>;
using LayoutQT = cute::Layout<ShapeQT, StrideQT>;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeMDT = cute::Shape<int32_t, int32_t>;
using StrideMDT = cute::Shape<int32_t, _1>;
using LayoutMDT = cute::Layout<ShapeMDT, StrideMDT>;
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<DTypeKV const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)), StrideT{}
),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_QKD{}),
_1{})); // no mcast for KV
static constexpr bool USE_TMA_LOAD_KV = Ktraits::USE_TMA_LOAD_KV;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ;
using PipelineParamsQ = typename MainloopPipelineQ::Params;
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
static constexpr uint32_t TmaTransactionBytesQ =
static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);
static constexpr uint32_t TmaTransactionBytesKV =
static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<DTypeKV> / 8);
// Host side kernel arguments
struct Arguments {
LayoutQT layout_Q;
LayoutT layout_KV;
LayoutMDT layout_MD;
DTypeQ const* Q_ptr;
DTypeKV const* KV_ptr;
DTypeMD const* m_ptr;
DTypeMD const* d_ptr;
IdType const* kv_block_tables;
IdType const* seq_lens_this_time;
IdType const* seq_lens_encoder;
IdType const* seq_lens_decoder;
IdType const* cumsum_q_seqlens;
IdType const* batch_ids;
IdType const* tile_ids_per_batch;
IdType const* num_blocks_x;
float sm_scale;
int bsz;
int max_block_num;
int max_block_num_per_seq;
int q_stride_bsz;
int q_stride_head_num;
int kv_stride_block_num;
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
int chunk_num;
int max_draft_token_num;
};
// Device side kernel params
struct Params {
LayoutQT layout_Q;
LayoutT layout_KV;
LayoutMDT layout_MD;
DTypeQ *Q_ptr;
DTypeKV* KV_ptr;
DTypeMD* m_ptr;
DTypeMD* d_ptr;
IdType* kv_block_tables;
IdType* seq_lens_this_time;
IdType* seq_lens_encoder;
IdType* seq_lens_decoder;
IdType* cumsum_q_seqlens;
IdType* batch_ids;
IdType* tile_ids_per_batch;
IdType* num_blocks_x;
float sm_scale;
int bsz;
int max_block_num;
int max_block_num_per_seq;
int q_stride_bsz;
int q_stride_head_num;
int kv_stride_block_num;
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
int chunk_num;
int max_draft_token_num;
TMA_KV tma_load_KV;
};
static Params to_underlying_arguments(Arguments const& args) {
TMA_KV tma_load_KV;
if constexpr (USE_TMA_LOAD_KV) {
Tensor mKV = make_tensor(make_gmem_ptr(args.KV_ptr), args.layout_KV);
tma_load_KV =
make_tma_copy(GmemTiledCopyKV{}, mKV, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{});
}
return {args.layout_Q,
args.layout_KV,
args.layout_MD,
const_cast<DTypeQ*>(args.Q_ptr),
const_cast<DTypeKV*>(args.KV_ptr),
const_cast<DTypeMD*>(args.m_ptr),
const_cast<DTypeMD*>(args.d_ptr),
const_cast<IdType*>(args.kv_block_tables),
const_cast<IdType*>(args.seq_lens_this_time),
const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_decoder),
const_cast<IdType*>(args.cumsum_q_seqlens),
const_cast<IdType*>(args.batch_ids),
const_cast<IdType*>(args.tile_ids_per_batch),
const_cast<IdType*>(args.num_blocks_x),
args.sm_scale,
args.bsz,
args.max_block_num,
args.max_block_num_per_seq,
args.q_stride_bsz,
args.q_stride_head_num,
args.kv_stride_block_num,
args.kv_stride_block_size,
args.o_stride_bsz,
args.o_stride_head_num,
args.chunk_size,
args.chunk_num,
args.max_draft_token_num,
tma_load_KV
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
if constexpr (USE_TMA_LOAD_KV) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_KV.get_tma_descriptor());
}
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_q(Params const& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_write_q,
SharedStorage& shared_storage,
const int thread_idx,
const int bid) {
int start_q_token_idx = mainloop_params.cumsum_q_seqlens[bid];
int offset_Q = mainloop_params.q_stride_bsz * start_q_token_idx;
Tensor mQ = make_tensor(make_gmem_ptr(mainloop_params.Q_ptr + offset_Q), mainloop_params.layout_Q);
Tensor gQ =
local_tile(mQ, select<0, 2>(TileShape_QKD{}), make_coord(_, _0{}))(_, _, _0{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor cQ = cute::make_identity_tensor(gQ.shape());
GmemTiledCopyQ gmem_tiled_copy_q;
auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(thread_idx);
Tensor tQgQ = gmem_thr_copy_q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_q.partition_D(sQ);
Tensor tQcQ = gmem_thr_copy_q.partition_D(cQ);
Tensor tQcQGroup = flatten_1(tQcQ);
int valid_q_size = mainloop_params.seq_lens_this_time[bid];
auto q_predicate_fn = [&](auto coords) {
auto s_coords = tQcQGroup(_0{}, coords);
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, valid_q_size);
};
Tensor tQgQiGroup = flatten_1(tQgQ);
Tensor tQsQiGroup = flatten_1(tQsQ);
pipeline_q.producer_acquire(smem_pipe_write_q);
copy_if(gmem_tiled_copy_q, q_predicate_fn, tQgQiGroup, tQsQiGroup);
pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
++smem_pipe_write_q;
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_kv(Params const& mainloop_params,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_write_kv,
SharedStorage& shared_storage,
const int bid,
const int kv_len,
const int tile_idx) {
int thread_idx = threadIdx.x;
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor mKV = make_tensor(make_gmem_ptr(mainloop_params.KV_ptr), mainloop_params.layout_KV);
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
GmemTiledCopy gmem_tiled_copy_kv;
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(
    make_gmem_ptr(mainloop_params.kv_block_tables),
    make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq),
                make_stride(mainloop_params.max_block_num_per_seq, 1)));
Tensor tKgK = gmem_thr_copy_kv.partition_S(gKV);
Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
const int block_idx = kv_block_tables(bid, kv_tile_idx);
pipeline_kv.producer_acquire(smem_pipe_write_kv);
Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, block_idx));
Tensor tKsKiGroup =
flatten_1(tKsK(_, _, _, smem_pipe_write_kv.index()));
copy(gmem_tiled_copy_kv, tKgKiGroup, tKsKiGroup);
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
++smem_pipe_write_kv;
}
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_kv_tma(Params const& mainloop_params,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_write_kv,
SharedStorage& shared_storage,
const int bid,
const int kv_len,
const int tile_idx) {
int thread_idx = threadIdx.x;
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor mKV = mainloop_params.tma_load_KV.get_tma_tensor(mainloop_params.layout_KV.shape());
// Prepare the TMA loads
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
auto [tKgK, tKsK] =
tma_partition(mainloop_params.tma_load_KV, _0{}, Layout<_1>{},
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(
    make_gmem_ptr(mainloop_params.kv_block_tables),
    make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq),
                make_stride(mainloop_params.max_block_num_per_seq, 1)));
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
#pragma unroll 2
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
const int block_idx = kv_block_tables(bid, kv_tile_idx);
pipeline_kv.producer_acquire(smem_pipe_write_kv);
copy(mainloop_params.tma_load_KV.with(*pipeline_kv.producer_get_barrier(smem_pipe_write_kv), /*mcast_mask=*/0),
tKgK(_, block_idx), tKsK(_, smem_pipe_write_kv.index()));
++smem_pipe_write_kv;
}
}
}
};
} // namespace mla_attn
#endif  // ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
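The producer walks one KV chunk per call: it derives the first and last paged-KV tile covered by chunk tile_idx and then issues one pipeline stage per tile, from the last tile down to the first. A standalone sketch of that bound computation (plain C++, no CuTe; helper name is hypothetical):

#include <algorithm>
// Sketch only, not part of the diff.
struct KvTileRange { int start_tile_idx; int end_tile_idx; };

inline KvTileRange kv_tiles_for_chunk(int tile_idx, int chunk_size,
                                      int kv_len, int block_shape_kv) {
  const int start_len = tile_idx * chunk_size;
  const int end_len = std::min(start_len + chunk_size, kv_len);
  // ceil_div(end_len, block_shape_kv) - 1, matching load_kv / load_kv_tma above
  return {start_len / block_shape_kv,
          (end_len + block_shape_kv - 1) / block_shape_kv - 1};
}
// The loop then runs kv_tile_idx = end_tile_idx .. start_tile_idx (descending),
// looking up the physical block id in kv_block_tables(bid, kv_tile_idx).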


@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
#define ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "named_barrier.cuh"
// #define DEBUG_MLA
namespace mla_attn {
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_read_q,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_read_kv,
FrgTensorO& tOrO,
AttentionUpdater& attention_updater,
const int thread_idx,
const int bid,
const int kv_len,
const int qo_len,
const int tile_idx,
SharedStorage& shared_storage) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = typename Ktraits::DTypeO;
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutP = typename Ktraits::SmemLayoutP;
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); // (bsz * draft_token_num * num_head)
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
typename Ktraits::TiledMmaQK tiled_mma_qk;
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup);
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 1) {
// consumer 0, compute qk
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
constexpr int n_masking_steps = !CAUSAL ? 1 : cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) + 1;
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
bool is_first_step = true;
// wait q
consumer_wait(pipeline_q, smem_pipe_read_q);
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
#pragma unroll 1
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
// wait kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// gemm qk
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
// mask
if (masking_step > 0) {
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
// update s (exp(s - m))
Tensor scale_o = is_first_step ? attention_updater.update</*init=*/true>(tSrS) : attention_updater.update</*init=*/false>(tSrS);
is_first_step = false;
Tensor convert_tSrS = convert_type<DTypeKV>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_tSrS);
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP);
cute::copy(scale_o, tScalesScale);
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
// make sure r2s all done
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
attention_updater.rescale_o(tOrO, scale_o);
// pv gemm
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV1(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV2(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
// sync WG1 WG2
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
}
// release q
pipeline_q.consumer_release(smem_pipe_read_q);
++smem_pipe_read_q;
// normalize
Tensor scale_o = attention_updater.finalize(tSrS); // warp reduce row sum
if (chunk_num_this_seq == 1) {
// norm
cute::copy(scale_o, tScalesScale);
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
attention_updater.rescale_o(tOrO, scale_o);
}
// WG1 write m,d back to gmem
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, e.g. t0 -> row0/row8, t4 -> row1/row9
const int warp_idx = thread_idx / 32;
#pragma unroll
for (int w_i = 0; w_i < 2; ++w_i) {
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
if (token_idx < qo_len) {
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
}
}
}
} else if (warp_group_idx == 2) {
// consumer 1, compute pv
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
// wait kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// A: tPsP
cute::copy(tScalesScale, scale_o);
// rescale
attention_updater.rescale_o(tOrO, scale_o);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV1(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV2(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
// sync WG1 WG2
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
}
if (chunk_num_this_seq == 1) {
// norm
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
cute::copy(tScalesScale, scale_o);
attention_updater.rescale_o(tOrO, scale_o);
}
}
return;
}
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_read_q,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_read_kv,
FrgTensorO& tOrO,
AttentionUpdater& attention_updater,
const int thread_idx,
const int bid,
const int kv_len,
const int qo_len,
const int tile_idx,
SharedStorage& shared_storage) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = typename Ktraits::DTypeO; // note: m/d values are cast through DTypeO (bf16) on write-back
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutP = typename Ktraits::SmemLayoutP;
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sVt_s3 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 2 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sVt_s4 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 3 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _);
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
typename Ktraits::TiledMmaQK tiled_mma_qk;
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup, _);
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrV3 = threadMmaPVSS.partition_fragment_B(sVt_s3);
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 1) {
// consumer 0, compute qk
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
// wait q
consumer_wait(pipeline_q, smem_pipe_read_q);
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
// wait k
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// first qk gemm
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
// mask
{
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
Tensor scale_o = attention_updater.update</*init=*/true>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) : 0;
--kv_tile_idx;
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
PipelineState smem_pipe_read_kv_cur = smem_pipe_read_kv;
++smem_pipe_read_kv;
// wait next kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// gemm next qk
gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
attention_updater.rescale_o(tOrO);
// last pv gemm
if (smem_pipe_read_kv_cur.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv_cur.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv_cur.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
// wait cur qk gemm
warpgroup_wait<1>();
// mask p
if (masking_step > 0) {
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
// update s (exp(s - m))
Tensor scale_o = attention_updater.update</*init=*/false>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
// make sure tSrS r2s done
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// wait last pv gemm
warpgroup_wait<0>();
// release last kv
pipeline_kv.consumer_release(smem_pipe_read_kv_cur);
}
// release q
pipeline_q.consumer_release(smem_pipe_read_q);
++smem_pipe_read_q;
// compute last pv
attention_updater.rescale_o(tOrO);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
scale_o = attention_updater.finalize(tSrS);
warpgroup_wait<0>();
// release last kv
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
if (chunk_num_this_seq == 1) {
// norm
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
attention_updater.rescale_o(tOrO);
}
// WG1 write m,d back to gmem
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, e.g. t0 -> row0/row8, t4 -> row1/row9
const int warp_idx = thread_idx / 32;
#pragma unroll
for (int w_i = 0; w_i < 2; ++w_i) {
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
if (token_idx < qo_len) {
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
}
}
}
} else if (warp_group_idx == 2) {
// consumer 1, compute pv
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
consumer_wait(pipeline_kv, smem_pipe_read_kv);
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// A: tPsP
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
// rescale
attention_updater.rescale_o(tOrO, scale_o);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
}
if (chunk_num_this_seq == 1) {
// norm
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
attention_updater.rescale_o(tOrO, scale_o);
}
}
return;
}
} // namespace mla_attn
#endif // ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
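Both mma variants above lean on the AttentionUpdater (defined in attention_updater.cuh, not shown in this diff) to keep a running row max m and row sum d and to rescale the partial PV accumulator whenever m changes. A scalar sketch of that online-softmax step, ignoring sm_scale and the register-fragment layout:

#include <cmath>
// Sketch only, not part of the diff: one online-softmax update per KV tile and row.
struct RowState { float m = -INFINITY; float d = 0.f; float o = 0.f; };

inline void online_softmax_step(RowState& st, const float* s, const float* v, int n) {
  float m_new = st.m;
  for (int j = 0; j < n; ++j) m_new = std::fmax(m_new, s[j]);  // running row max
  const float scale = std::exp(st.m - m_new);                  // rescale old partials
  st.d *= scale;
  st.o *= scale;
  for (int j = 0; j < n; ++j) {
    const float p = std::exp(s[j] - m_new);                    // exp(s - m)
    st.d += p;
    st.o += p * v[j];                                          // PV accumulation
  }
  st.m = m_new;
}
// finalize() divides o by d, but only in the single-chunk path; otherwise m and d
// are written back to gmem (the WG1 write-back above) and merged across chunks later.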


@@ -0,0 +1,575 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_
#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <type_traits>
#include <vector>
#include "attention_updater.cuh"
#include "cute/tensor.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "epilogue.cuh"
#include "helper.h"
#include "kernel_traits.cuh"
#include "mainloop_mma.cuh"
#include "mainloop_load.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
struct Params {
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_encoder;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *padding_offsets;
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
uint32_t kv_stride_block_num;
uint32_t kv_stride_block_size;
uint32_t o_stride_bsz;
uint32_t o_stride_head_num;
int bsz;
int token_num;
int max_seq_len;
int max_block_num;
int max_block_num_per_seq;
int q_num_head;
int qk_head_dim;
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_size;
int chunk_num;
int num_blocks_x_int;
float sm_scale;
};
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 64) { \
constexpr size_t GROUP_SIZE = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
return cudaErrorNotSupported; \
}
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveEpilogue::Params const epilogue_params) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeO = typename Ktraits::DTypeO;
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
using TileShape_PDV = typename Ktraits::TileShape_PDV;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
const int num_blocks_x = mainloop_params.num_blocks_x[0];
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
using PipelineParamsQ = typename MainloopPipelineQ::Params;
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
extern __shared__ char shared_memory[];
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
if constexpr (use_tma_load_kv) {
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NUM_MMA_THREADS;
} else {
pipeline_params.producer_arv_count = NUM_COPY_THREADS;
pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
}
PipelineParamsQ pipeline_params_q;
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
: MainloopPipelineQ::ThreadCategory::Consumer;
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
MainloopPipeline pipeline_kv = [&] {
if constexpr (use_tma_load_kv) {
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
/*cluster_shape=*/Shape<_1, _1, _1>{});
} else {
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
}
}();
__syncthreads();
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue;
if (warp_group_idx == 0) {
// producer
if constexpr(USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_dealloc<72>();
}
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
const int block_id = blockIdx.x;
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
// consumer
if constexpr(USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_alloc<216>();
}
PipelineStateQ smem_pipe_read_q;
PipelineState smem_pipe_read_kv;
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
} else {
const int block_id = blockIdx.x;
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
using DTypeKV = typename KernelTraits::DTypeKV;
using DTypeO = typename KernelTraits::DTypeO;
using IdType = typename KernelTraits::IdType;
using NV_TYPE = typename KernelTraits::NV_TYPE;
using CollectiveMainloop =
CollectiveMainloop<KernelTraits, CAUSAL>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
params.Q,
params.KV,
params.m,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_encoder,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.sm_scale,
params.bsz,
params.max_block_num,
params.max_block_num_per_seq,
params.q_stride_bsz,
params.q_stride_head_num,
params.kv_stride_block_num,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_size,
params.chunk_num,
params.max_draft_token_num
});
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
params.O,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
params.O_tmp,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
});
// Get the ptr to kernel function.
auto kernel =
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132, USE_REG_EALLOC, USE_FIXED_BLOCK>;
int smem_size = sizeof(typename KernelTraits::SharedStorage);
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
int device;
cudaGetDevice(&device);
int multiprocessor_count;
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
int act_blocks_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
int gridx;
if constexpr(USE_FIXED_BLOCK) {
gridx = multiprocessor_count;
} else {
gridx = params.num_blocks_x_int;
}
dim3 grid_dims(gridx, 1, 1);
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
mainloop_params, epilogue_params
);
if (params.chunk_num > 1) {
constexpr int vec_size = 16 / sizeof(DTypeO);
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.seq_lens_encoder,
params.padding_offsets,
reinterpret_cast<NV_TYPE*>(params.O),
params.max_seq_len,
params.chunk_num,
params.q_num_head,
params.chunk_size,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num
);
}
return cudaSuccess;
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
/*BLOCK_SHAPE_Q_=*/64,
/*BLOCK_SHAPE_KV_=*/64,
/*NUM_STAGES_=*/2,
typename Params::DTypeQ,
typename Params::DTypeKV,
typename Params::DTypeO,
typename Params::IdType,
NV_TYPE>,
CAUSAL,
Params,
USE_REG_EALLOC,
USE_FIXED_BLOCK>(params, stream);)
} else {
return cudaErrorNotSupported;
}
return cudaSuccess;
};
} // namespace mla_attn
#endif // ATTENTION_HOPPER_PREFILL_SM90_CUH_

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
#define ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
#include <cuda_runtime.h>
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
namespace mla_attn {
enum class NamedBarriers {
kQueryEmpty = 0,
kValueEmpty = 1,
kWarpSchedulerWG1 = 2,
kWarpSchedulerWG2 = 3,
kWarpSchedulerWG3 = 4,
kPrefetchIndices = 5,
kOdone = 6,
kWG1WG2Sync = 7,
kWG0WG1WG2Sync = 8,
kWG1WG2LastSync = 9,
};
} // namespace mla_attn
#endif // ATTENTION_HOPPER_NAMED_BARRIERS_CUH_

View File

@@ -0,0 +1,351 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_UTILS_CUH_
#define ATTENTION_HOPPER_UTILS_CUH_
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include <assert.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <stdlib.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cuda_runtime.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cmath>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "cutlass/fast_math.h"
namespace mla_attn {
using namespace cute;
template <typename TensorT>
CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) {
Tensor tensor_flatten = cute::flatten(tensor);
return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten);
}
CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride,
int64_t h_stride) {
return make_layout(make_shape(nnz, head_dim, num_heads),
make_stride(n_stride, cute::_1{}, h_stride));
}
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) {
return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads)));
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
int head_idx, int offset, int seq_len) {
auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)),
make_coord(offset, _0{}));
auto g_sequence =
make_tensor(g_offset.data(),
make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride()));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
int head_idx, int offset, int seq_len) {
auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset));
auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len),
cute::make_shape(shape<0>(m_tensor))));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
return g_tensor;
}
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V,
// MMA_N))
template <typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
};
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16,
// MMA_N))
template <typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout),
make_layout(get<2, 1>(l), get<2>(acc_layout)));
};
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <bool init = false, int wg_wait = 0, typename TensorA, typename TensorB, typename TensorC,
typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB,
TensorC& tCrC) {
constexpr bool Is_RS =
!cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
}
warpgroup_fence_operand(tCrC);
warpgroup_arrive();
if constexpr (init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
warpgroup_commit_batch();
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
}
}
#define HOSTDEVICE __host__ __device__
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};
template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
const AlignedVector<T, Size>* addr_vec =
reinterpret_cast<const AlignedVector<T, Size>*>(addr);
*vec = *addr_vec;
}
template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
AlignedVector<T, Size>* addr_vec =
reinterpret_cast<AlignedVector<T, Size>*>(addr);
*addr_vec = vec;
}
template <size_t vec_size, typename T>
struct prefill_softmax_state_t {
AlignedVector<T, vec_size> o;
float m;
float d;
__device__ __forceinline__ void init() {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = -3.38953e38f;
}
}
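  // Online-softmax merge of another partial state (other_o, other_m, other_d):
  //   m_new = max(m, other_m); d_new = d * exp(m - m_new) + other_d * exp(other_m - m_new);
  //   o_new = o * exp(m - m_new) + other_o * exp(other_m - m_new).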
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
const float other_m,
const float other_d) {
float m_prev = m, d_prev = d;
m = max(m_prev, other_m);
const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
d = d_prev * scale1 + other_d * scale2;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
}
}
__device__ __forceinline__ void normalize() {
const T d_t = static_cast<T>(d);
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d_t;
}
}
};
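// Merges the per-chunk partial attention results written when a long KV sequence is
// processed in chunks: each block strides over tokens, every threadIdx.y lane accumulates
// its subset of chunks with the online-softmax update above, and lane ty == 0 then
// combines the bdy partial states from shared memory into the final output row.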
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
const int * __restrict__ seq_lens_this_time,
const int * __restrict__ seq_lens_decoder,
const int * __restrict__ seq_lens_encoder,
const int * __restrict__ padding_offsets,
T * __restrict__ out, // [token_num, num_heads, head_dim]
const int max_seq_len,
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim,
const int token_num,
const int bsz,
const int max_draft_token_num) {
const int vid = threadIdx.x, ty = threadIdx.y;
const int hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t ori_token_id = qid + padding_offsets[qid];
const uint32_t bid = ori_token_id / max_seq_len;
const int seq_len_q = seq_lens_this_time[bid];
if (seq_len_q == 0) continue;
const uint32_t local_seq_id = ori_token_id % max_seq_len;
int seq_len_kv = seq_lens_decoder[bid];
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
// no merge needed
continue;
}
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
float d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.0e+30f;
}
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset;
offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid;
float m_prev = m;
float d_prev = d;
const float m_now = multi_m[offset];
const float d_now = multi_d[offset];
m = max(m_prev, m_now);
offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) * head_dim + vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
d = d * scale1 + d_now * scale2;
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
}
}
// store this lane-row's (ty) partial result
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
md_smem[2 * ty] = m;
md_smem[2 * ty + 1] = d;
__syncthreads();
if (ty == 0) {
// merge bdy
prefill_softmax_state_t<vec_size, T> st;
st.init();
#pragma unroll
for (int i = 0; i < bdy; i++) {
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
Store<T, vec_size>(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
}
__syncthreads();
}
}
} // namespace mla_attn
#endif // ATTENTION_HOPPER_UTILS_CUH_

View File

@@ -1255,8 +1255,6 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float) {
if (is_new_zp) {
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) {

View File

@@ -0,0 +1,469 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "append_attn/multi_head_latent_attention_kernel.h"
#include "mla_attn/batch_mla_with_paged_kv_cache.h"
template <paddle::DataType D>
std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& query_bias,
const paddle::optional<paddle::Tensor>& query_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const std::string& cache_quant_type_str,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
typedef PDTraits<D> traits_;
typedef typename traits_::data_t data_t;
int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];
const bool mla_use_tensorcore = get_mla_use_tensorcore();
auto sm_version = GetSMVersion();
if ((speculate_decoder || mla_use_tensorcore) && sm_version < 90) {
PD_THROW("Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm < 90.");
}
auto main_stream = query.stream();
paddle::Tensor fmha_out = paddle::full(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
0,
D,
query.place());
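  // Decode path: use the tensor-core MLA kernel (SM90) when enabled, otherwise fall back
  // to the generic MLA decode kernel.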
if (max_dec_len_this_time_data > 0) {
if (mla_use_tensorcore) {
BatchMLAWithPagedKVCacheKernel<data_t>(meta_data,
query,
key_cache,
attn_mask,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
cache_quant_type_str,
decoder_num_blocks_data,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
main_stream,
&fmha_out);
} else {
DecodeMLAAttentionKernel<data_t>(
meta_data,
query, // [token_num, num_heads, head_dim]
key_cache,
value_cache,
attn_mask,
out_linear_shifts,
out_linear_smooths,
seq_lens_this_time, // q_seq_len is 1
seq_lens_decoder,
padding_offsets,
cum_offsets,
block_tables,
max_input_length,
max_len_kv_data,
softmax_scale,
out_linear_in_scale,
causal,
main_stream,
&fmha_out);
}
}
return {fmha_out};
}
std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& query_bias,
const paddle::optional<paddle::Tensor>& query_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
AppendAttnMetaData meta_data;
const auto& query_dims = query.dims();
const auto& key_cache_dims = key_cache.dims();
const int q_hidden_size = query_dims[query_dims.size() - 1];
meta_data.token_nums = query_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = cum_offsets.dims()[0];
switch (query.dtype()) {
case paddle::DataType::BFLOAT16: {
return MultiHeadLatentAttentionKernel<paddle::DataType::BFLOAT16>(
meta_data,
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
max_dec_len_this_time,
max_len_kv,
attn_mask,
query_bias,
query_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
cache_quant_type_str,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
speculate_decoder);
}
case paddle::DataType::FLOAT16: {
return MultiHeadLatentAttentionKernel<paddle::DataType::FLOAT16>(
meta_data,
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
max_dec_len_this_time,
max_len_kv,
attn_mask,
query_bias,
query_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
cache_quant_type_str,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
speculate_max_draft_token_num,
causal,
speculate_decoder);
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}
}
std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& query_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
const std::vector<int64_t>& max_enc_len_this_time_shape,
const std::vector<int64_t>& max_dec_len_this_time_shape,
const std::vector<int64_t>& max_len_kv_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
const paddle::optional<std::vector<int64_t>>& query_bias_shape,
const paddle::optional<std::vector<int64_t>>& query_out_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
const int token_num = query_shape[0];
const int kv_num_heads = key_cache_shape[1];
const int head_dim_qk = key_cache_shape[3];
const int head_dim_v = nope_size;
const int q_hidden_size = query_shape[query_shape.size() - 1];
const int num_heads = q_hidden_size / head_dim_qk;
return {{token_num, num_heads * head_dim_v}};
}
std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& query_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& decoder_num_blocks_cpu_dtype,
const paddle::DataType& max_enc_len_this_time_dtype,
const paddle::DataType& max_dec_len_this_time_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
const paddle::optional<paddle::DataType>& query_bias_dtype,
const paddle::optional<paddle::DataType>& query_out_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
if (compute_dtype == "bf16") {
return {paddle::DataType::BFLOAT16};
} else if (compute_dtype == "fp16") {
return {paddle::DataType::FLOAT16};
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
}
}
PD_BUILD_OP(multi_head_latent_attention)
.Inputs({"query",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seqlens_q",
"padding_offsets",
"cum_offsets",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"decoder_num_blocks_cpu",
"max_enc_len_this_time",
"max_dec_len_this_time",
"max_len_kv",
paddle::Optional("attn_mask"),
paddle::Optional("query_bias"),
paddle::Optional("query_out_scales"),
paddle::Optional("cache_k_quant_scales"),
paddle::Optional("cache_v_quant_scales"),
paddle::Optional("cache_k_dequant_scales"),
paddle::Optional("cache_v_dequant_scales"),
paddle::Optional("cache_k_zp"),
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths")})
.Outputs({"fmha_out"})
.Attrs({"compute_type: std::string",
"cache_quant_type: std::string",
"nope_size: int",
"max_input_length: int",
"softmax_scale: float",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(MultiHeadLatentAttention))
.SetInferShapeFn(PD_INFER_SHAPE(MultiHeadLatentAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MultiHeadLatentAttentionInferDtype));

View File

@@ -0,0 +1,73 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <optional>
#include "helper.h"
#include "noauxtc_kernel.h"
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
paddle::Tensor& scores_with_bias,
int n_group,
int topk_group,
int topk,
float routed_scaling_factor) {
auto input_shape = scores_with_bias.shape();
int64_t num_tokens = input_shape[0];
int64_t num_experts = input_shape[1];
auto input_type = scores_with_bias.dtype();
auto place = scores_with_bias.place();
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
auto stream = scores_with_bias.stream();
invokeNoAuxTc<float>(reinterpret_cast<float*>(scores.data<float>()),
reinterpret_cast<float*>(group_scores.data<float>()),
reinterpret_cast<float*>(scores_with_bias.data<float>()),
num_tokens,
num_experts,
n_group,
topk_group,
topk,
routed_scaling_factor,
stream);
return {scores};
}
std::vector<paddle::DataType> NoauxTcInferDtype(
const paddle::DataType& scores_dtype,
const paddle::DataType& scores_with_bias_dtype) {
return {scores_dtype};
}
std::vector<std::vector<int64_t>> NoauxTcInferShape(
const std::vector<int64_t>& scores_shape,
const std::vector<int64_t>& gating_output_shape) {
return {scores_shape};
}
PD_BUILD_OP(noaux_tc)
.Inputs({"scores", "scores_with_bias"})
.Outputs({"output_tensor"})
.Attrs({"n_group: int",
"topk_group: int",
"topk:int",
"routed_scaling_factor: float"})
.SetKernelFn(PD_KERNEL(NoauxTc))
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcInferDtype));

View File

@@ -0,0 +1,551 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This code is partially inspired by and references the implementation found
// in NVIDIA TRTLLM.
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
namespace warp_topk {
template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
if (len == 0) {
return 0;
}
return ((len - 1) / size + 1) * size;
}
template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
return (v && !(v & (v - 1)));
}
template <bool greater, typename T>
__device__ bool is_better_than(T val, T baseline) {
return (val > baseline && greater) || (val < baseline && !greater);
}
template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
return max(cache_topk,
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
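// Warp-level bitonic sorting network: each thread keeps size / WARP_SIZE elements in
// registers; the compare-and-swap steps between registers handle the larger strides, and
// the 32-element (one element per lane) specializations below exchange values across
// lanes with __shfl_xor_sync.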
template <int size, bool ascending, typename T, typename idxT>
struct BitonicMerge {
// The input is a bitonic sequence; merge it into a monotonic (sorted) sequence.
__device__ static void merge(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
static_assert(isPowerOf2(size));
static_assert(size >= 2 * WARP_SIZE);
constexpr int arr_len = size / WARP_SIZE;
constexpr int stride = arr_len / 2;
for (int i = 0; i < stride; ++i) {
int const other_i = i + stride;
T& val = val_arr[i];
T& other_val = val_arr[other_i];
if ((val > other_val && ascending) || (val < other_val && !ascending)) {
T tmp = val;
val = other_val;
other_val = tmp;
idxT tmp2 = idx_arr[i];
idx_arr[i] = idx_arr[other_i];
idx_arr[other_i] = tmp2;
}
}
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
idx_arr + arr_len / 2);
}
};
template <int size, bool ascending, typename T, typename idxT>
struct BitonicSort {
__device__ static void sort(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
static_assert(isPowerOf2(size));
static_assert(size >= 2 * WARP_SIZE);
constexpr int arr_len = size / WARP_SIZE;
BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
idx_arr + arr_len / 2);
BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
}
};
template <bool ascending, typename T, typename idxT>
struct BitonicSort<32, ascending, T, idxT> {
__device__ static void sort(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
int const lane = threadIdx.x % WARP_SIZE;
// ascending doesn't matter before merging since all we need is a bitonic
// sequence
for (int stage = 0; stage < 4; ++stage) {
for (int stride = (1 << stage); stride > 0; stride /= 2) {
bool reverse = (lane >> stage) & 2;
bool is_second = lane & stride;
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
*val_arr = other;
*idx_arr = other_idx;
}
}
}
BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
}
};
template <bool ascending, typename T, typename idxT>
struct BitonicMerge<32, ascending, T, idxT> {
__device__ static void merge(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
int const lane = threadIdx.x % WARP_SIZE;
for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
bool is_second = lane & stride;
T& val = *val_arr;
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
idxT& idx = *idx_arr;
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
if (val != other && ((val > other) == (ascending != is_second))) {
val = other;
idx = other_idx;
}
}
}
};
template <int capacity, bool greater, typename T, typename idxT>
class WarpSort {
public:
__device__ WarpSort(idxT k, T dummy)
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
for (int i = 0; i < max_arr_len_; ++i) {
val_arr_[i] = dummy_;
}
}
// load and merge k sorted values
__device__ void load_sorted(T const* __restrict__ in,
idxT const* __restrict__ in_idx,
idxT start) {
idxT idx = start + WARP_SIZE - 1 - lane_;
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
if (idx < start + k_) {
T t = in[idx];
if (is_better_than<greater>(t, val_arr_[i])) {
val_arr_[i] = t;
idx_arr_[i] = in_idx[idx];
}
}
}
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
}
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
for (int i = 0; i < max_arr_len_; ++i) {
idxT out_i = i * WARP_SIZE + lane_;
if (out_i < k_) {
out[out_i] = val_arr_[i];
out_idx[out_i] = idx_arr_[i];
}
}
}
__device__ void dumpIdx(idxT* __restrict__ out_idx) const {
for (int i = 0; i < max_arr_len_; ++i) {
idxT out_i = i * WARP_SIZE + lane_;
if (out_i < k_) {
out_idx[out_i] = idx_arr_[i];
}
}
}
protected:
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
T val_arr_[max_arr_len_];
idxT idx_arr_[max_arr_len_];
int const lane_;
idxT const k_;
T const dummy_;
}; // end class WarpSort
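// Streaming top-k selection built on WarpSort: candidates are pre-filtered against the
// current k-th best value (k_th_), staged in a per-warp shared-memory buffer, and merged
// into the sorted register arrays whenever a full warp's worth of candidates has
// accumulated.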
template <int capacity, bool greater, typename T, typename idxT>
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
public:
__device__ WarpSelect(idxT k, T dummy)
: WarpSort<capacity, greater, T, idxT>(k, dummy),
k_th_(dummy),
k_th_lane_((k - 1) % WARP_SIZE) {
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
int const num_of_warp = blockDim.x / WARP_SIZE;
int const warp_id = threadIdx.x / WARP_SIZE;
val_smem_ = reinterpret_cast<T*>(smem_buf);
val_smem_ += warp_id * WARP_SIZE;
idx_smem_ = reinterpret_cast<idxT*>(
smem_buf +
round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
idx_smem_ += warp_id * WARP_SIZE;
}
__device__ void add(T const* in, idxT start, idxT end) {
idxT const end_for_fullwarp =
round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
T val = (i < end) ? in[i] : dummy_;
add(val, i);
}
}
__device__ void add(T val, idxT idx) {
bool do_add = is_better_than<greater>(val, k_th_);
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
if (mask == 0) {
return;
}
int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
if (do_add && pos < WARP_SIZE) {
val_smem_[pos] = val;
idx_smem_[pos] = idx;
do_add = false;
}
smem_buf_len_ += __popc(mask);
if (smem_buf_len_ >= WARP_SIZE) {
__syncwarp();
merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
smem_buf_len_ -= WARP_SIZE;
}
if (do_add) {
pos -= WARP_SIZE;
val_smem_[pos] = val;
idx_smem_[pos] = idx;
}
__syncwarp();
}
__device__ void done() {
if (smem_buf_len_) {
T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
merge_buf_(val, idx);
}
// after done(), smem is used for merging results among warps
__syncthreads();
}
private:
__device__ void set_k_th_() {
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
}
__device__ void merge_buf_(T val, idxT idx) {
BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
T& old = val_arr_[max_arr_len_ - 1];
if (is_better_than<greater>(val, old)) {
old = val;
idx_arr_[max_arr_len_ - 1] = idx;
}
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
set_k_th_();
}
using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
using WarpSort<capacity, greater, T, idxT>::val_arr_;
using WarpSort<capacity, greater, T, idxT>::idx_arr_;
using WarpSort<capacity, greater, T, idxT>::lane_;
using WarpSort<capacity, greater, T, idxT>::k_;
using WarpSort<capacity, greater, T, idxT>::dummy_;
T* val_smem_;
idxT* idx_smem_;
int smem_buf_len_ = 0;
T k_th_;
int const k_th_lane_;
}; // end class WarpSelect
} // namespace warp_topk
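// Per-group routing score used for group selection: the sum of the two largest values in
// the group, reduced warp-wide with cooperative groups.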
template <typename T>
__device__ void topk_with_k2(T* output,
T const* input,
cg::thread_block_tile<32> const& tile,
int32_t const lane_id,
int const num_experts_per_group) {
// Get the top2 per thread
T largest = cuda::std::numeric_limits<T>::min();
T second_largest = cuda::std::numeric_limits<T>::min();
if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = input[i];
if (value > largest) {
second_largest = largest;
largest = value;
} else if (value > second_largest) {
second_largest = value;
}
}
} else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
largest = input[i];
}
}
__syncwarp(); // Ensure all threads have valid data before reduction
// Get the top2 warpwise
T max1 = cg::reduce(tile, largest, cg::greater<T>());
T max2 = max1;
bool equal_to_max1 = (max1 == largest);
int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));
if (count_max1 == 1) {
largest = (largest == max1) ? second_largest : largest;
max2 = cg::reduce(tile, largest, cg::greater<T>());
}
if (lane_id == 0) {
*output = max1 + max2;
}
}
template <typename T>
__global__ void topk_with_k2_kernel(T* output,
T* input,
int64_t const num_tokens,
int64_t const num_cases,
int64_t const n_group,
int64_t const num_experts_per_group) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
if (case_id < num_cases) {
input += case_id * num_experts_per_group;
output += case_id;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
}
}
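// One warp per token: select the topk_group groups with the highest group scores, run a
// warp-wide top-k over the (bias-adjusted) scores inside those groups, then write the
// normalized, routed_scaling_factor-scaled weights back to `scores`, zeroing all other
// experts.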
template <typename T>
__global__ void group_idx_and_topk_idx_kernel(
T* scores,
T const* group_scores,
T* scores_with_bias,
int64_t const num_tokens,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
int64_t const num_experts,
int64_t const num_experts_per_group,
double routed_scaling_factor) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
scores_with_bias += case_id * num_experts;
scores += case_id * num_experts;
group_scores += case_id * n_group;
int32_t align_num_experts_per_group =
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
// store the target topk idx
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
T* s_topk_value =
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
warp_id * topk;
T value = cuda::std::numeric_limits<T>::min();
T topk_group_value = cuda::std::numeric_limits<T>::min();
int32_t num_equalto_topkth_group;
if ((n_group > topk_group) && (case_id < num_tokens)) {
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
if (lane_id < n_group) {
value = group_scores[lane_id];
}
int count_equal_to_top_value = WARP_SIZE - n_group;
int pre_count_equal_to_top_value = 0;
// Loop to find the topk_group-th largest group score (ties are resolved below)
while (count_equal_to_top_value < target_num_min) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = cuda::std::numeric_limits<T>::min();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value = __popc(__ballot_sync(
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
}
__syncthreads();
warp_topk::WarpSelect</*capacity*/ WARP_SIZE, /*greater*/ true, T, int32_t>
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
int count_equalto_topkth_group = 0;
if (case_id < num_tokens) {
for (int i_group = 0; i_group < n_group; i_group++) {
if ((group_scores[i_group] > topk_group_value) ||
((group_scores[i_group] == topk_group_value) &&
(count_equalto_topkth_group < num_equalto_topkth_group))) {
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates = i < num_experts_per_group
? scores_with_bias[offset + i]
: cuda::std::numeric_limits<T>::min();
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
count_equalto_topkth_group++;
}
}
}
queue.done();
__syncwarp();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
__syncwarp();
}
// Load the valid score value
// Calculate the summation
float topk_sum = 1e-20;
if (case_id < num_tokens) {
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
T value = i < topk ? scores[s_topk_idx[i]]
: 0.0f; // Load the valid value of expert
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, value, cg::plus<float>());
}
}
__syncthreads();
if (case_id < num_tokens) {
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
scores[i] = 0;
}
}
__threadfence();
__syncthreads();
if (case_id < num_tokens) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
scores[s_topk_idx[i]] = value;
}
}
}
template <typename T>
void invokeNoAuxTc(T* scores,
T* group_scores,
T* scores_with_bias,
int64_t const num_tokens,
int64_t const num_experts,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
double const routed_scaling_factor,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores,
scores_with_bias,
num_tokens,
num_cases,
n_group,
num_experts / n_group);
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
BLOCK_SIZE,
dynamic_smem_in_bytes,
stream>>>(scores,
group_scores,
scores_with_bias,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
routed_scaling_factor);
}
#define INSTANTIATE_NOAUX_TC(T) \
template void invokeNoAuxTc<T>(T * scores, \
T * group_scores, \
T * scores_with_bias, \
int64_t const num_tokens, \
int64_t const num_experts, \
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
double const routed_scaling_factor, \
cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float);

View File

@@ -50,11 +50,13 @@ __global__ void quant_per_token_per_block(const T *input,
max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
}
// get max value per warp
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread);
// broadcast max_value
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
max_value_thread = max(max_value_thread, epsilon);
float scale_to_store = max_value_thread / MAX_VALUE;
// quant

View File

@@ -267,6 +267,9 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/text_image_index_out.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
]
# pd_disaggregation
@@ -376,6 +379,8 @@ elif paddle.is_compiled_with_cuda():
# append_attention
sources += ["gpu_ops/append_attention.cu"]
sources += find_end_files("gpu_ops/append_attn", ".cu")
# mla
sources += ["gpu_ops/multi_head_latent_attention.cu"]
# gemm_dequant
sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"]
# speculate_decoding
@@ -441,6 +446,10 @@ elif paddle.is_compiled_with_cuda():
sources += find_end_files(fp8_auto_gen_directory, ".cu")
if cc >= 90 and nvcc_version >= 12.0:
# Hopper optimized mla
sources += find_end_files("gpu_ops/mla_attn", ".cu")
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(

View File

@@ -9,11 +9,6 @@ COPY . /workspace/FastDeploy
RUN echo "ulimit -u unlimited" >> /root/.bashrc
RUN echo "ulimit -n 65536" >> /root/.bashrc
# setting proxy
ARG http_proxy=agent.baidu.com:8891
ARG https_proxy=agent.baidu.com:8891
ARG no_proxy=localhost,bj.bcebos.com,su.bcebos.com,pypi.tuna.tsinghua.edu.cn,paddle-ci.gz.bcebos.com
# uninstall existing package
RUN python -m pip uninstall paddlepaddle-gpu fastdeploy-gpu -y

View File

@@ -7,11 +7,6 @@ deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy main restricted universe
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-updates main restricted universe multiverse \n\
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-backports main restricted universe multiverse" > /etc/apt/sources.list
# setting proxy
ENV http_proxy=http://agent.baidu.com:8891
ENV https_proxy=http://agent.baidu.com:8891
ENV no_proxy=localhost,bj.bcebos.com,su.bcebos.com,pypi.tuna.tsinghua.edu.cn,paddle-ci.gz.bcebos.com
RUN apt-get update && apt-get install -y libibverbs-dev librdmacm-dev cmake pybind11-dev
# uninstall existing package
@@ -40,4 +35,4 @@ COPY . /workspace/FastDeploy
RUN cd /workspace/FastDeploy && bash build.sh && python -m pip install --no-cache-dir dist/* && rm -rf /workspace/FastDeploy
ENV http_proxy=""
ENV https_proxy=""
ENV https_proxy=""

View File

@@ -1,19 +1,19 @@
# Chain-of-Thought Content
# Reasoning Outputs
The reasoning model returns a `reasoning_content` field in the output, representing the chain-of-thought content—the reasoning steps that lead to the final conclusion.
Reasoning models return an additional `reasoning_content` field in their output, which contains the reasoning steps that led to the final conclusion.
## Currently Supported Chain-of-Thought Models
| Model Name | Parser Name | Chain-of-Thought Enabled by Default |
|----------------|----------------|-------------------------------------|
| ernie-45-vl | ernie-45-vl | ✓ |
| ernie-lite-vl | ernie-45-vl | ✓ |
## Supported Models
| Model Name | Parser Name | Enable Thinking by Default |
|----------------|----------------|---------------------------|
| baidu/ERNIE-4.5-VL-424B-A47B-Paddle | ernie-45-vl | ✓ |
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | ernie-45-vl | ✓ |
The reasoning model requires a specified parser to interpret the reasoning content. The reasoning mode can be disabled by setting the `enable_thinking=False` parameter.
The reasoning model requires a specified parser to extract reasoning content. The reasoning mode can be disabled by setting the `enable_thinking=False` parameter.
Interfaces that support toggling the reasoning mode:
1. `/v1/chat/completions` request in OpenAI services.
2. `/v1/chat/completions` request in the OpenAI Python client.
3. `llm.chat` request in Offline interfaces.
1. `/v1/chat/completions` requests in OpenAI services.
2. `/v1/chat/completions` requests in the OpenAI Python client.
3. `llm.chat` requests in Offline interfaces.
For reasoning models, the length of the reasoning content can be controlled via `reasoning_max_tokens`. Add `metadata={"reasoning_max_tokens": 1024}` to the request.
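Below is a minimal sketch of setting this field through the OpenAI Python client. It is illustrative only: the port follows the launch and curl examples below, the model name is a placeholder, and `extra_body` is used as the client's generic pass-through for FastDeploy-specific fields.
```python
from openai import OpenAI

# Assumes a FastDeploy OpenAI-compatible server is already running on port 8192
# (see the launch example below); the model name is a placeholder.
client = OpenAI(base_url="http://0.0.0.0:8192/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",  # placeholder served-model name
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    # Non-standard fields are passed through extra_body; reasoning_max_tokens caps
    # the length of the returned reasoning content.
    extra_body={"metadata": {"reasoning_max_tokens": 1024}},
)
print(response.choices[0].message.content)
```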
@@ -21,10 +21,15 @@ For reasoning models, the length of the reasoning content can be controlled via
When launching the model service, specify the parser name using the `--reasoning-parser` argument.
This parser will process the model's output and extract the `reasoning_content` field.
```bash
python -m fastdeploy.entrypoints.openai.api_server --model /root/merge_llm_model --enable-mm --tensor-parallel-size=8 --port 8192 --quantization wint4 --reasoning-parser=ernie-45-vl
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/your/model \
--enable-mm \
--tensor-parallel-size 8 \
--port 8192 \
--quantization wint4 \
--reasoning-parser ernie-45-vl
```
Next, send a `chat completion` request to the model:
Next, make a request to the model; the response should include the reasoning content.
```bash
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
-H "Content-Type: application/json" \
@@ -40,8 +45,8 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
```
The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself.
### Streaming Sessions
In streaming sessions, the `reasoning_content` field can be retrieved from the `delta` in `chat completion response chunks`.
### Streaming Chat Completions
Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field of `chat completion response chunks`.
```python
from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server.

View File

@@ -62,7 +62,7 @@ The differences in request parameters between FastDeploy and the OpenAI protocol
- `stream_options`: Optional[StreamOptions] = None
- `temperature`: Optional[float] = None
- `top_p`: Optional[float] = None
- `metadata`: Optional[dict] = None (supported only in `v1/chat/completions` for configuring additional parameters, e.g., `meta_data={"enable_thinking": True}`)
- `metadata`: Optional[dict] = None (supported only in `v1/chat/completions` for configuring additional parameters, e.g., `metadata={"enable_thinking": True}`)
- `min_tokens`: Optional[int] = 1 (minimum number of tokens generated)
- `reasoning_max_tokens`: Optional[int] = None (maximum number of tokens for reasoning content, defaults to the same as `max_tokens`)
- `enable_thinking`: Optional[bool] = True (whether to enable reasoning for models that support deep thinking)
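For illustration only, a request sketch that exercises a few of the FastDeploy-specific fields above over plain HTTP (the host, port, and model name are placeholders; field placement follows the examples in this list):
```python
import requests

payload = {
    "model": "default",  # placeholder; the model served by FastDeploy
    "messages": [{"role": "user", "content": "Summarize chunked prefill in one sentence."}],
    "metadata": {"enable_thinking": True},  # extra parameters are passed via `metadata`
    "min_tokens": 1,
    "stream": False,
}
resp = requests.post("http://127.0.0.1:8192/v1/chat/completions", json=payload, timeout=300)
print(resp.json()["choices"][0]["message"]["content"])
```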

View File

@@ -27,7 +27,7 @@ When using FastDeploy to deploy models (including offline inference and service
| ```kv_cache_ratio``` | `float` | KVCache blocks are divided between Prefill phase and Decode phase according to kv_cache_ratio ratio, default: 0.75 |
| ```enable_prefix_caching``` | `bool` | Whether to enable Prefix Caching, default: False |
| ```swap_space``` | `float` | When Prefix Caching is enabled, CPU memory size for KVCache swapping, unit: GB, default: None |
| ```enable_chunk_prefill``` | `bool` | Enable Chunked Prefill, default: False |
| ```enable_chunked_prefill``` | `bool` | Enable Chunked Prefill, default: False |
| ```max_num_partial_prefills``` | `int` | When Chunked Prefill is enabled, maximum concurrent number of partial prefill batches, default: 1 |
| ```max_long_partial_prefills``` | `int` | When Chunked Prefill is enabled, maximum number of long requests in concurrent partial prefill batches, default: 1 |
| ```long_prefill_token_threshold``` | `int` | When Chunked Prefill is enabled, requests with token count exceeding this value are considered long requests, default: max_model_len*0.04 |
@@ -115,5 +115,5 @@ FastDeploy initialization sequence first uses `gpu_memory_utilization` parameter
...
```
- When ```use_cudagraph``` is enabled, currently only supports single-GPU inference, i.e. ```tensor_parallel_size``` set to 1.
- When ```use_cudagraph``` is enabled, cannot enable ```enable_prefix_caching``` or ```enable_chunk_prefill```.
- When ```use_cudagraph``` is enabled, cannot enable ```enable_prefix_caching``` or ```enable_chunked_prefill```.
- When ```use_cudagraph``` is enabled, batches with size ≤ ```max_capture_batch_size``` are executed by CudaGraph, while larger batches fall back to the original dynamic/static graph. To have all batch sizes executed by CudaGraph, set ```max_capture_batch_size``` equal to ```max_num_seqs```. Setting ```max_capture_batch_size``` larger than ```max_num_seqs``` wastes time and memory by capturing batch sizes that will never occur during inference.
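For illustration, a hypothetical offline configuration consistent with the constraints above; the keyword arguments are assumed to map one-to-one onto the engine parameters listed in this document:
```python
# Hypothetical sketch: offline deployment settings consistent with the CudaGraph
# constraints above. Keyword names mirror the engine parameters in this document
# and are assumed to be accepted by the offline LLM interface.
from fastdeploy.entrypoints.llm import LLM

llm = LLM(
    model="/path/to/your/model",
    tensor_parallel_size=1,        # CudaGraph currently supports single-GPU only
    use_cudagraph=True,
    max_num_seqs=8,
    max_capture_batch_size=8,      # match max_num_seqs so every batch is captured
    enable_prefix_caching=False,   # must stay off when use_cudagraph is enabled
    enable_chunked_prefill=False,  # likewise
)
```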

View File

@@ -57,3 +57,6 @@ On the ERNIE-4.5-300B-A47B model, comparison of WINT2 vs WINT4 performance:
| IFEval |500|88.17 | 85.40 |
|BBH|6511|94.43|92.02|
|DROP|9536|91.17|89.97|
|GSM8K|1319|96.21|95.98|
|CMath|600|96.50|96.00|
|CMMLU|11477|89.92|86.22|

View File

@@ -5,8 +5,8 @@
## Models that currently support chain-of-thought reasoning
| Model Name | Parser Name | Reasoning Enabled by Default |
|---------------|-------------|---------|
| ernie-45-vl | ernie-45-vl | ✓ |
| ernie-lite-vl | ernie-45-vl | ✓ |
| baidu/ERNIE-4.5-VL-424B-A47B-Paddle | ernie-45-vl | ✓ |
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | ernie-45-vl | ✓ |
Reasoning models require a parser to be specified so that the reasoning content can be parsed out. The model's thinking mode can be disabled via the `enable_thinking=False` parameter.

View File

@@ -61,7 +61,7 @@ The differences in request parameters between FastDeploy and the OpenAI protocol are listed below; the remaining request parameters
- `stream_options`: Optional[StreamOptions] = None
- `temperature`: Optional[float] = None
- `top_p`: Optional[float] = None
- `metadata`: Optional[dict] = None (supported only in `v1/chat/completions` for configuring additional parameters, e.g., `meta_data={"enable_thinking": True}`)
- `metadata`: Optional[dict] = None (supported only in `v1/chat/completions` for configuring additional parameters, e.g., `metadata={"enable_thinking": True}`)
- `min_tokens`: Optional[int] = 1 (minimum number of tokens to generate)
- `reasoning_max_tokens`: Optional[int] = None (maximum number of tokens for reasoning content; defaults to the same value as `max_tokens`)
- `enable_thinking`: Optional[bool] = True (whether to enable reasoning for models that support deep thinking)

View File

@@ -26,7 +26,7 @@
| ```kv_cache_ratio``` | `float` | KVCache blocks are split between the Prefill phase and the Decode phase according to the kv_cache_ratio, default: 0.75 |
| ```enable_prefix_caching``` | `bool` | Whether to enable Prefix Caching, default: False |
| ```swap_space``` | `float` | When Prefix Caching is enabled, the CPU memory size used for swapping KVCache, unit: GB, default: None |
| ```enable_chunk_prefill``` | `bool` | Enable Chunked Prefill, default: False |
| ```enable_chunked_prefill``` | `bool` | Enable Chunked Prefill, default: False |
| ```max_num_partial_prefills``` | `int` | When Chunked Prefill is enabled, the maximum concurrency of partial prefill batches, default: 1 |
| ```max_long_partial_prefills``` | `int` | When Chunked Prefill is enabled, the maximum number of long requests among concurrent partial prefills, default: 1 |
| ```long_prefill_token_threshold``` | `int` | When Chunked Prefill is enabled, requests whose token count exceeds this value are treated as long requests, default: max_model_len*0.04 |
@@ -113,5 +113,5 @@ FastDeploy's initialization sequence first uses the `gpu_memory_utilization` parameter to compute
...
```
- When ```use_cudagraph``` is enabled, only single-GPU inference is currently supported, i.e. ```tensor_parallel_size``` set to 1.
- When ```use_cudagraph``` is enabled, ```enable_prefix_caching``` or ```enable_chunk_prefill``` cannot be enabled.
- When ```use_cudagraph``` is enabled, ```enable_prefix_caching``` or ```enable_chunked_prefill``` cannot be enabled.
- When ```use_cudagraph``` is enabled, batches with size ≤ ```max_capture_batch_size``` are executed by CudaGraph, while larger batches fall back to the original dynamic/static graph. To have all batch sizes executed by CudaGraph, set ```max_capture_batch_size``` equal to ```max_num_seqs```. Setting ```max_capture_batch_size``` larger than ```max_num_seqs``` wastes time and memory by capturing batch sizes that will never occur during inference.

View File

@@ -15,11 +15,14 @@
"""
import os
import subprocess
import sys
# suppress warning log from paddlepaddle
os.environ["GLOG_minloglevel"] = "2"
# suppress log from aistudio
os.environ["AISTUDIO_LOG"] = "critical"
from fastdeploy.utils import version
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
@@ -30,3 +33,48 @@ try:
use_triton_in_paddle.make_triton_compatible_with_paddle()
except ImportError:
pass
# TODO(tangbinhan): remove this code
def _patch_fastsafetensors():
try:
file_path = subprocess.check_output([
sys.executable, "-c", "import fastsafetensors, os; \
print(os.path.join(os.path.dirname(fastsafetensors.__file__), \
'frameworks', '_paddle.py'))"
]).decode().strip()
with open(file_path, 'r') as f:
content = f.read()
if "DType.U16: DType.BF16," in content and "DType.U8: paddle.uint8," in content:
return
modified = False
if "DType.U16: DType.BF16," not in content:
lines = content.splitlines()
new_lines = []
inside_block = False
for line in lines:
new_lines.append(line)
if 'need_workaround_dtypes: Dict[DType, DType] = {' in line:
inside_block = True
elif inside_block and '}' in line:
new_lines.insert(-1, ' DType.U16: DType.BF16,')
inside_block = False
modified = True
content = "\n".join(new_lines)
if "DType.I8: paddle.uint8," in content:
content = content.replace("DType.I8: paddle.uint8,",
"DType.U8: paddle.uint8,")
modified = True
if modified:
with open(file_path, 'w') as f:
f.write(content + "\n")
except Exception as e:
print(f"Failed to patch fastsafetensors: {e}")
_patch_fastsafetensors()

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from typing import Literal, Optional
from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -51,7 +51,6 @@ class ModelConfig(PretrainedConfig):
top_p = 0.0
temperature = 1.0
rope_theta = 10000.0
rope_scaling = None
penalty_score = 1.0
frequency_score = 0.0
presence_score = 0.0
@@ -70,7 +69,6 @@ class ModelConfig(PretrainedConfig):
max_seq_len: int = 512,
initializer_range: float = 0.02,
use_rope=True,
use_fast_ffn: bool = False,
rope_theta: int = 10000,
rope_3d: bool = False,
ori_vocab_size: int | None = None,
@@ -105,7 +103,6 @@ class ModelConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.use_rope = use_rope
self.use_fast_ffn = use_fast_ffn
self.rope_theta = rope_theta
self.ori_vocab_size = ori_vocab_size or vocab_size
self.max_seq_len = max_seq_len
@@ -142,6 +139,7 @@ class MoEConfig:
moe_num_shared_experts = (0, )
moe_layer_start_index = 0
moe_layer_end_index = None
moe_use_aux_free: bool = False
num_max_dispatch_tokens_per_rank = 256
im_patch_id = (
100295 # multimodality, TODO(liuyuanle): read from config.json
@@ -163,7 +161,6 @@ class ParallelConfig:
# The embedding weight distributed on your gpu cards is divided by row or column.
# Defaults to False means divide by row. When vocab_size can not be divided by world_size
# but hidden_size can, we can consider split embedding weight by column.
column_cut = False # (bool, optional)
"""
From old version worker args
TODO(gongshaotian): Reclassify
@@ -194,18 +191,13 @@ class ParallelConfig:
engine_pid: Optional[int] = None
# Do profile or not
do_profile: bool = False
# Dynamic load weight or not
dynamic_load_weight: bool = False
#
pad_token_id: int = -1
#
eos_tokens_lens: int = 2
# Enable chunked prefill
enable_chunked_prefill: str = "store_true"
"""
- APPEND_ATTN:
"""
attention_backend: str = "APPEND_ATTN"
max_num_batched_tokens: int = 2048
# enable prefix cache
enable_prefix_caching = None
@@ -354,9 +346,27 @@ class GraphOptimizationConfig:
@dataclass
class LoadConfig:
"""
Configuration for loading parameter
Configuration for dynamic weight loading strategies
Attributes:
dynamic_load_weight: Whether to enable dynamic weight loading
load_strategy: Specifies the weight loading method when enabled:
- 'ipc': Real-time IPC streaming with automatic resharding
- 'ipc_no_reshard': Real-time IPC streaming without weight resharding
- 'ipc_snapshot': Load from disk snapshot of IPC weights
- 'meta': for RL training workers; no weights are loaded
- None: No dynamic loading
"""
pass
use_fastsafetensor: bool = False
dynamic_load_weight: bool = False
load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
def __post_init__(self):
if self.load_strategy is not None and not self.dynamic_load_weight:
raise ValueError("Load strategy requires dynamic_load_weight=True")
if self.dynamic_load_weight and self.load_strategy is None:
raise ValueError("Must specify load_strategy when dynamic_load_weight is True")
@dataclass
@@ -392,7 +402,7 @@ class FDConfig:
init=True) # type: ignore
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True)
quant_config: Optional[QuantConfigBase] = None
graph_opt_config: Optional[GraphOptimizationConfig] = None
moe_config: MoEConfig = field(default=None, init=True) # type: ignore

View File

@@ -16,48 +16,54 @@
import time
import os
import subprocess
import signal
import multiprocessing
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.engine.sampling_params import SamplingParams
model_name_or_path = "./models/eb45t02/"
model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle"
prefill_cmd = (f"FD_LOG_DIR=log_prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python fastdeploy.entrypoints.openai.api_server.py"
+ f" --model {model_name_or_path} --port 9811"
+ f" --splitwise-role prefill --tensor-parallel-size 4"
+ f" --engine-worker-queue-port 6676 --cache-queue-port 55663")
def start_decode(model_name_or_path):
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["FD_LOG_DIR"] = "log_decode"
llm_decode = LLM(
model=model_name_or_path,
tensor_parallel_size=1,
splitwise_role="decode",
engine_worker_queue_port=6678,
innode_prefill_ports=[6676],
cache_queue_port=55668
)
return llm_decode
prefill_instance = subprocess.Popen(
prefill_cmd,
stdout=subprocess.PIPE,
shell=True,
preexec_fn=os.setsid,
)
def start_prefill(model_name_or_path):
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["FD_LOG_DIR"] = "log_prefill"
llm_prefill = LLM(
model=model_name_or_path,
tensor_parallel_size=1,
splitwise_role="prefill",
engine_worker_queue_port=6677,
cache_queue_port=55667,
)
def main():
prefill = multiprocessing.Process(
target=start_prefill,
args=(model_name_or_path,)).start()
time.sleep(10)
llm_decode = start_decode(model_name_or_path)
output = llm_decode.generate(prompts=["who are you", "what can you do"], use_tqdm=True)
print(output)
decode.join()
# Hyperparameter settings
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
os.environ["FD_LOG_DIR"] = "log_decode"
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
llm_decode = LLM(
model=model_name_or_path,
tensor_parallel_size=4,
splitwise_role="decode",
engine_worker_queue_port=6678,
innode_prefill_ports=[6676],
cache_queue_port=55668
)
output = llm_decode.generate(prompts=["who are you", "what can you do"], use_tqdm=True)
print(output)
os.killpg(prefill_instance.pid, signal.SIGTERM)
if __name__ == "__main__":
main()

View File

@@ -17,13 +17,15 @@
import paddle
import paddle.distributed as dist
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
if paddle.in_dynamic_mode():
hcg = dist.fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
dist.all_reduce(input_, group=mp_group)
else:
dist.all_reduce(input_)
try:
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
if paddle.in_dynamic_mode():
hcg = dist.fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
dist.all_reduce(input_, group=mp_group)
else:
dist.all_reduce(input_)
except:
tensor_model_parallel_all_reduce=None

View File

@@ -87,10 +87,14 @@ class EngineArgs:
"""
Configuration for speculative execution.
"""
dynamic_load_weight: int = 0
dynamic_load_weight: bool = False
"""
dynamic load weight
"""
load_strategy: str = "meta"
"""
dynamic load weight strategy
"""
quantization: str = None
guided_decoding_backend: str = "off"
"""
@@ -364,13 +368,16 @@ class EngineArgs:
type=json.loads,
default=EngineArgs.speculative_config,
help="Configuration for speculative execution.")
model_group.add_argument(
"--dynamic-load-weight",
type=int,
action='store_true',
default=EngineArgs.dynamic_load_weight,
help="Flag to indicate whether to load weight dynamically.")
model_group.add_argument(
"--load-strategy",
type=str,
default=EngineArgs.load_strategy,
help="Flag to dynamic load strategy.")
model_group.add_argument("--engine-worker-queue-port",
type=int,
default=EngineArgs.engine_worker_queue_port,
@@ -383,6 +390,7 @@ class EngineArgs:
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
model_group.add_argument(
"--enable-static-graph-inference",
action='store_true',
@@ -668,8 +676,9 @@ class EngineArgs:
"""
return ModelConfig(model_name_or_path=self.model,
config_json_file=self.model_config_name,
quantization=self.quantization,
dynamic_load_weight=self.dynamic_load_weight,
quantization=self.quantization)
load_strategy=self.load_strategy)
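A hypothetical sketch of driving the same two options programmatically; the `EngineArgs` module path is an assumption, and only the fields shown in this diff are used:
```python
# Hypothetical sketch: constructing EngineArgs with the new flags and letting
# create_model_config() propagate them into ModelConfig, as shown above.
from fastdeploy.engine.args_utils import EngineArgs  # assumed module path

args = EngineArgs(
    model="/path/to/your/model",
    dynamic_load_weight=True,      # exposed on the CLI as --dynamic-load-weight
    load_strategy="ipc_snapshot",  # one of: ipc, ipc_no_reshard, ipc_snapshot, meta
)
model_cfg = args.create_model_config()
print(model_cfg.load_strategy)  # "ipc_snapshot"
```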
def create_cache_config(self, model_cfg) -> CacheConfig:
"""
@@ -749,6 +758,9 @@ class EngineArgs:
speculative_cfg = self.create_speculative_config()
assert not (self.use_cudagraph and self.enable_prefix_caching), \
"Prefix caching cannot be used with CUDA graph"
return Config(
model_name_or_path=self.model,
model_config=model_cfg,

View File

@@ -17,6 +17,7 @@
import json
import os
from datetime import datetime
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional
from fastdeploy import envs
@@ -41,7 +42,8 @@ class ModelConfig:
def __init__(self,
model_name_or_path: str,
config_json_file: str = "config.json",
dynamic_load_weight: int = 0,
dynamic_load_weight: bool = False,
load_strategy: str="meta",
quantization: str = None,
download_dir: Optional[str] = None):
"""
@@ -55,6 +57,7 @@ class ModelConfig:
self.model_dir = model_name_or_path
self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
self.dynamic_load_weight = dynamic_load_weight
self.load_strategy = load_strategy
self.quantization = quantization
config_file = os.path.join(model_name_or_path, config_json_file)
@@ -465,7 +468,63 @@ class ParallelConfig:
llm_logger.info("Parallel Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info("==================")
llm_logger.info(
"=============================================================")
@dataclass
class CommitConfig:
"""
Configuration for tracking version information from version.txt
Attributes:
fastdeploy_commit: Full FastDeploy git commit hash
paddle_version: PaddlePaddle version string
paddle_commit: PaddlePaddle git commit hash
cuda_version: CUDA version string
compiler_version: CXX compiler version string
"""
fastdeploy_commit: str = ""
paddle_version: str = ""
paddle_commit: str = ""
cuda_version: str = ""
compiler_version: str = ""
def __post_init__(self):
"""Automatically load version info when initialized"""
self._load_from_version_file()
def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
"""Internal method to load version info from file"""
try:
with open(file_path, 'r') as f:
for line in f:
line = line.strip()
if line.startswith("fastdeploy GIT COMMIT ID:"):
self.fastdeploy_commit = line.split(":")[1].strip()
elif line.startswith("Paddle version:"):
self.paddle_version = line.split(":")[1].strip()
elif line.startswith("Paddle GIT COMMIT ID:"):
self.paddle_commit = line.split(":")[1].strip()
elif line.startswith("CUDA version:"):
self.cuda_version = line.split(":")[1].strip()
elif line.startswith("CXX compiler version:"):
self.compiler_version = line.split(":")[1].strip()
except FileNotFoundError:
llm_logger.info(f"Warning: Version file not found at {file_path}")
except Exception as e:
llm_logger.info(f"Warning: Could not read version file - {str(e)}")
def print(self):
"""
print all config
"""
llm_logger.info("Fasedeploy Commit Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
class Config:
@@ -500,6 +559,7 @@ class Config:
cache_config: CacheConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
commit_config: CommitConfig = CommitConfig(),
model_name_or_path: str = None,
tokenizer: str = None,
tensor_parallel_size: int = 8,
@@ -559,6 +619,7 @@ class Config:
self.cache_config = cache_config
self.scheduler_config = scheduler_config
self.parallel_config = parallel_config
self.commit_config = commit_config
self.model_name_or_path = model_name_or_path
self.tokenizer = tokenizer
self.max_num_batched_tokens = max_num_batched_tokens
@@ -584,12 +645,10 @@ class Config:
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
if self.innode_prefill_ports is not None:
if not isinstance(self.innode_prefill_ports, list):
ports = str(self.innode_prefill_ports).split(',')
self.innode_prefill_ports = [int(port) for port in ports]
assert self.splitwise_role in ["mixed", "prefill", "decode"]
@@ -728,7 +787,7 @@ class Config:
), "XPU currently do not support guided_decoding"
try:
pass
import xgrammar # noqa
except Exception as e:
raise Exception(
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
@@ -749,7 +808,11 @@ class Config:
if k == "generation_config" and v is not None:
for gck, gcv in v.to_dict().items():
llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
elif (k == "cache_config" or
k == "model_config" or
k == "scheduler_config" or
k == "parallel_config" or
k == "commit_config"):
v.print()
else:
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))

View File

@@ -286,6 +286,8 @@ class LLMEngine(object):
while self.running:
try:
results = self.scheduler.get_results()
if len(results) == 0:
time.sleep(0.001)
for request_id, contents in results.items():
for result in contents:
self.zmq_server.send_multipart(request_id, result)
@@ -444,8 +446,8 @@ class LLMEngine(object):
enable_thinking = None
if kwargs is not None:
enable_thinking = kwargs.get("enable_thinking", None)
request = self.data_processor.process_request(request,
self.cfg.max_model_len, enable_thinking=enable_thinking)
request = self.data_processor.process_request(
request, self.cfg.max_model_len, enable_thinking=enable_thinking)
request.prompt_token_ids_len = len(request.prompt_token_ids)
input_ids_len = request.prompt_token_ids_len
request.set(
@@ -453,7 +455,8 @@ class LLMEngine(object):
min(self.cfg.max_model_len - input_ids_len,
request.get("max_tokens")))
if request.get("reasoning_max_tokens") is None:
default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1)
default_reasoning_max_tokens = max(
int(request.get("max_tokens") * 0.8), 1)
request.set("reasoning_max_tokens", default_reasoning_max_tokens)
min_tokens = request.get("min_tokens")
if input_ids_len + min_tokens >= self.cfg.max_model_len:
@@ -963,8 +966,8 @@ class LLMEngine(object):
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring",
"FLAGS_hardamard_moe_block_size": 128,
"FLAGS_max_partition_size": 32768,
"FLAGS_hardamard_moe_block_size": 128,
}
# environment variables needed by Dy2St
variables.update({
@@ -1017,6 +1020,12 @@ class LLMEngine(object):
worker_path = "../worker/vl_worker_process.py"
py_script = os.path.join(current_dir_path, worker_path)
ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model')
else len(self.data_processor.tokenizer.vocab)
)
arguments = (
f" --nnodes {str(self.cfg.nnode)}"
f" --devices {self.cfg.device_ids} {py_script}"
@@ -1037,13 +1046,14 @@ class LLMEngine(object):
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
f" --quantization {self.cfg.model_config.quantization}"
f" --ori_vocab_size {len(self.data_processor.tokenizer)}"
f" --ori_vocab_size {ori_vocab_size}"
f" --speculative_method {self.cfg.speculative_config.method}"
f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}"
f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}")
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}")
worker_append_flag = {
"enable_expert_parallel":
@@ -1188,8 +1198,9 @@ class LLMEngine(object):
line = line.decode('utf-8', errors='ignore')
if self.worker_init_status.get("finished", False):
break
if match := re.search(r'Loading checkpoint shards:\s*(\d+)',
line):
if match := re.search(
r'Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)',
line):
self.worker_init_status["weight_loadding"] = eval(
match.group(1)) * 1.0 / 100
elif (match := re.search(r'Start load layer (\d+)',

View File

@@ -122,7 +122,7 @@ class ChatCompletionResponseChoice(BaseModel):
"""
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
class ChatCompletionResponse(BaseModel):

View File

@@ -220,6 +220,9 @@ class OpenAIServingChat:
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
choice.finish_reason = "recover_stop"
if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
choice.delta.token_ids = output["token_ids"]
@@ -335,6 +338,9 @@ class OpenAIServingChat:
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]:
choice.finish_reason = "recover_stop"
choices.append(choice)
num_prompt_tokens = len(prompt_token_ids)
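A small client-side sketch of how the new finish reason might be handled (illustrative only):
```python
# Illustrative sketch: reacting to the new "recover_stop" finish reason on the client.
def handle_choice(choice) -> None:
    if choice.finish_reason == "recover_stop":
        # The server recovered/rescheduled the request (error_msg contained "Recover"),
        # so the output may be incomplete; retrying is a reasonable strategy.
        print("Stopped by server-side recovery; consider resending the request.")
    elif choice.finish_reason == "length":
        print("Stopped after hitting the token limit.")
    else:
        print("Finished:", choice.finish_reason)
```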

View File

@@ -82,13 +82,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_MOE_BACKEND":
lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
# Set whether to disable recompute the request when the KV cache is full.
"FD_DISABLED_RECOVER":
lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
# Set triton kernel JIT compilation directory.
"FD_TRITON_KERNEL_CACHE_DIR":
lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None),
# Whether transition from standalone PD decoupling to centralized inference
"FD_PD_CHANGEABLE":
lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
# Whether to use fastsafetensor load weight (0 or 1)
"FD_USE_FASTSAFETENSOR":
lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
}
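A short sketch of how these variables are consumed; attribute-style access mirrors the `envs.FD_ATTENTION_BACKEND` usage elsewhere in this change set:
```python
# Sketch: the lambdas above are evaluated on access, so the variables can be set
# before FastDeploy reads them and then retrieved through fastdeploy.envs.
import os

os.environ["FD_USE_FASTSAFETENSOR"] = "1"  # opt in to fastsafetensor weight loading
os.environ["FD_PD_CHANGEABLE"] = "0"       # keep the new default shown above

from fastdeploy import envs

print(envs.FD_USE_FASTSAFETENSOR)  # "1"
print(envs.FD_PD_CHANGEABLE)       # "0"
```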

View File

@@ -27,6 +27,7 @@ from fastdeploy.input.text_processor import BaseDataProcessor
_SAMPLING_EPS = 1e-5
class ErnieProcessor(BaseDataProcessor):
"""
Initialize the model instance.
@@ -160,6 +161,7 @@ class ErnieProcessor(BaseDataProcessor):
if request.get('prompt'):
prompt = request.get('prompt')
prompt = prompt[0] if isinstance(prompt, list) else prompt
tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request['prompt_token_ids'] = token_ids

View File

@@ -82,6 +82,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
# Pre-compute the set of all special tokens once to accelerate decoding.
@property
def space_token(self):
@@ -136,14 +137,19 @@ class ErnieBotTokenizer(PretrainedTokenizer):
"""doc"""
return self.sp_model.id_to_piece(id)
def spec_init(self):
if not hasattr(self, "all_spec_tok"):
self.all_spec_tok = set(self.all_special_tokens)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
self.spec_init()
current_sub_tokens = []
out_string = ""
# prev_is_special = False
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if token in self.all_spec_tok:
# if not prev_is_special:
# out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
@@ -210,13 +216,14 @@ class ErnieBotTokenizer(PretrainedTokenizer):
# if isinstance(t, AddedToken)
# )
self.spec_init()
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
# TODO: should this be in the base class?
if hasattr(self, "do_lower_case") and self.do_lower_case:
# convert non-special tokens to lowercase
escaped_special_toks = [
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok)
]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
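A tiny self-contained illustration of the caching idea behind `spec_init`: the set of special tokens is built once and reused for membership checks instead of consulting `all_special_tokens` on every call:
```python
# Illustration only: cache special tokens in a set once, then do O(1) lookups.
special_tokens = ["<pad>", "<unk>", "<s>", "</s>"]  # stand-ins for all_special_tokens
all_spec_tok = set(special_tokens)                  # built once, as in spec_init()

for tok in ["Hello", "</s>", "world"]:
    if tok in all_spec_tok:  # cheap membership test on the cached set
        print("special:", tok)
    else:
        print("regular:", tok)
```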

View File

@@ -25,6 +25,7 @@ from fastdeploy.utils import data_processor_logger
_SAMPLING_EPS = 1e-5
class BaseDataProcessor(ABC):
"""base class for data processor"""

View File

@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .attention import Attention
from .append_attn_backend import AppendAttentionBackend
from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .mla_attention_backend import MLAAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend
__all__ = [
"Attention", "AttentionBackend", "PaddleNativeAttnBackend",
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend"
"AttentionBackend", "PaddleNativeAttnBackend",
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
"MLAAttentionBackend", "FlashAttentionBackend"
]

View File

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
@@ -187,6 +187,8 @@ class AppendAttentionBackend(AttentionBackend):
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:

View File

@@ -111,6 +111,8 @@ class Attention(nn.Layer):
k: paddle.Tensor = None,
v: paddle.Tensor = None,
qkv: paddle.Tensor = None,
compressed_kv: paddle.Tensor = None,
k_pe: paddle.Tensor = None,
forward_meta: ForwardMeta = None,
) -> paddle.Tensor:
"""
@@ -120,12 +122,16 @@ class Attention(nn.Layer):
k: the key tensor
v: the value tensor
forward_meta: the forward meta data
compressed_kv: optional compressed key-value cache (for MLA)
k_pe: optional key positional encoding (for MLA)
"""
return forward_meta.attn_backend.forward(
q,
k,
v,
qkv,
compressed_kv,
k_pe,
self,
forward_meta,
)
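An illustrative sketch of the updated call shape: call sites that do not use MLA simply thread `None` through for the two new arguments, exactly as `Attention.forward` does above:
```python
# Sketch of the new backend call shape; non-MLA layers pass None for both fields.
def run_attention(q, k, v, qkv, layer, forward_meta):
    return forward_meta.attn_backend.forward(
        q, k, v, qkv,
        None,  # compressed_kv: only used by the MLA backend
        None,  # k_pe: only used by the MLA backend
        layer,
        forward_meta,
    )
```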

View File

@@ -16,6 +16,7 @@
from functools import cache
from fastdeploy import envs
from fastdeploy.platforms import _Backend, current_platform
from fastdeploy.utils import resolve_obj_from_strname
@@ -40,6 +41,7 @@ def _get_attn_backend(selected_backend: str) -> object:
return resolve_obj_from_strname(attention_cls)
def get_attention_backend(selected_backend):
"""Selects which attention backend ."""
return _get_attn_backend(selected_backend)
def get_attention_backend() -> object:
"""Selects which attention backend."""
attention_backend = envs.FD_ATTENTION_BACKEND
return _get_attn_backend(attention_backend)

View File

@@ -46,6 +46,8 @@ class AttentionBackend(ABC):
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: paddle.nn.Layer,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
@@ -56,6 +58,8 @@ class AttentionBackend(ABC):
k: The key tensor.
v: The value tensor.
layer: The layer that will be used for the forward.
compressed_kv: optional compressed key-value cache (for MLA)
k_pe: optional key positional encoding (for MLA)
forward_meta: The forward metadata.
"""
if forward_meta.forward_mode.is_mixed():
@@ -64,6 +68,8 @@ class AttentionBackend(ABC):
k,
v,
qkv,
compressed_kv,
k_pe,
layer,
forward_meta,
)
@@ -73,6 +79,8 @@ class AttentionBackend(ABC):
k,
v,
qkv,
compressed_kv,
k_pe,
layer,
forward_meta,
)
@@ -82,6 +90,8 @@ class AttentionBackend(ABC):
k,
v,
qkv,
compressed_kv,
k_pe,
layer,
forward_meta,
)
@@ -92,6 +102,8 @@ class AttentionBackend(ABC):
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: paddle.nn.Layer,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
@@ -104,6 +116,8 @@ class AttentionBackend(ABC):
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: paddle.nn.Layer,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
@@ -116,6 +130,8 @@ class AttentionBackend(ABC):
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: paddle.nn.Layer,
forward_meta: ForwardMeta,
) -> paddle.Tensor:

View File

@@ -0,0 +1,247 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import List, Optional
import paddle
try:
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except:
flash_attention_v3_varlen = None
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
from fastdeploy.worker.forward_meta import ForwardMeta
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
"""
FlashAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
encoder_block_shape_q: Optional[paddle.Tensor] = None
decoder_block_shape_q: Optional[paddle.Tensor] = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
pre_cache_batch_ids = None
pre_cache_tile_ids_per_batch = None
pre_cache_num_blocks_cpu = None
kv_token_num_cpu = None
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
class FlashAttentionBackend(AttentionBackend):
"""
FlashAttentionBackend backend implementation
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int):
"""
FlashAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: FlashAttentionMetadata = None
self.max_seq_len = fd_config.parallel_config.max_model_len
self.causal = getattr(fd_config.model_config, "causal", True)
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = fd_config.model_config.head_dim
self.hidden_size = fd_config.model_config.hidden_size
self.block_size = fd_config.parallel_config.block_size
self.num_layers: int = fd_config.model_config.num_layers
self.speculative_method = fd_config.speculative_config.method
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
# pd_disaggregation
self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
fd_config.parallel_config.expert_parallel_rank
if self.device_id is None:
self.device_id = device_id
else:
self.device_id = self.device_id.split(",")[device_id]
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Calculate kv cache shape
"""
return (max_num_blocks, self.kv_num_heads, self.block_size,
self.head_dim)
def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cum_offsets,
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
metadata.set_max_lengths[2],
self.block_size,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag)
self.attention_metadata = metadata
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
forward_meta.decoder_tile_ids_per_batch.copy_(
metadata.decoder_tile_ids_per_batch, False)
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
):
metadata = self.attention_metadata
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0],
self.max_seq_len,
getattr(layer, "cache_quant_type_str", "none"),
)
res = flash_attention_v3_varlen(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=metadata.set_max_lengths[0],
max_seqlen_k=metadata.set_max_lengths[3],
causal=self.causal,
)[0].reshape([-1, self.hidden_size])
return res

View File

@@ -0,0 +1,490 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import math
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block, init_signal_layerwise,
open_shm_and_get_meta_signal)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (decode_mla_write_cache,
multi_head_latent_attention,
prefill_mla_write_cache)
if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
def yarn_get_mscale(scale=1, mscale=1):
"""
"""
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
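A self-contained worked example of the formula above, using the concrete values hinted at in the comments further down (scaling factor 40, mscale_all_dim 1.0); the head-dimension number is an assumption for illustration:
```python
import math

def yarn_get_mscale(scale=1, mscale=1):
    # Same formula as above: no rescaling when scale <= 1.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

m = yarn_get_mscale(40, 1.0)                 # 0.1 * ln(40) + 1.0 ≈ 1.3689
qk_head_dim = 192                            # assumed: 128 nope + 64 rope dims
softmax_scale = qk_head_dim ** -0.5 * m * m  # mirrors attn_softmax_scale below
print(round(m, 4), round(softmax_scale, 5))  # 1.3689 0.13524
```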
@dataclass
class MLAAttentionMetadata(AttentionMetadata):
"""
MLAAttentionMetadata for Multi-Layer Attention
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
_dtype: _DTypeLiteral = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: Optional[paddle.Tensor] = None
decoder_block_shape_q: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
class MLAAttentionBackend(AttentionBackend):
"""
MLA Attention Backend implementation.
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int) -> None:
"""
MLAAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: MLAAttentionMetadata = None
# Basic configuration
self.block_size: int = fd_config.parallel_config.block_size
self.max_seq_len: int = fd_config.parallel_config.max_model_len
self.rope_theta: float = (10000.0
if fd_config.model_config.rope_theta is None
else fd_config.model_config.rope_theta)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_layers
# For Multi Head Latent Attention
self.kv_lora_rank: int = fd_config.model_config.deepseekv3.kv_lora_rank
self.qk_rope_head_dim: int = fd_config.model_config.deepseekv3.qk_rope_head_dim
self.qk_head_dim: int = fd_config.model_config.deepseekv3.qk_nope_head_dim \
+ fd_config.model_config.deepseekv3.qk_rope_head_dim
self.attn_softmax_scale: float = self.qk_head_dim**-0.5
if fd_config.model_config.deepseekv3.rope_scaling:
mscale_all_dim = fd_config.model_config.deepseekv3.rope_scaling.get(
"mscale_all_dim", False) # 1.0
scaling_factor = fd_config.model_config.deepseekv3.rope_scaling[
"factor"] # 40
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
# pd_disaggregation
self.use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if self.device_id is None:
self.device_id = self.rank
else:
self.device_id = self.device_id.split(",")[self.rank]
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = MLAAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.max_partition_size = 32768
metadata.encoder_max_partition_size = self.max_seq_len
metadata._dtype = paddle.get_default_dtype()
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cum_offsets,
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
# MLA
metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.use_pd_disaggregation:
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag)
self.attention_metadata: AttentionMetadata = metadata
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(self,
max_num_blocks: int) -> Tuple[int, int, int, int]:
"""
Calculate kv cache shape for MLA
"""
return (max_num_blocks, 1, self.block_size,
self.kv_lora_rank + self.qk_rope_head_dim)
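A quick illustrative shape check; the concrete numbers (kv_lora_rank 512, qk_rope_head_dim 64, block_size 64) are assumed, typical DeepSeek-V3-style values:
```python
# Illustration: MLA stores one shared latent head per block instead of per-head K/V.
max_num_blocks, block_size = 1000, 64
kv_lora_rank, qk_rope_head_dim = 512, 64  # assumed example values
shape = (max_num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim)
print(shape)  # (1000, 1, 64, 576)
```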
def forward_extend(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for the prefill stage.
"""
metadata = self.attention_metadata
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
forward_meta, 'caches') else None
# Write to the latent KV cache
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
"none",
getattr(forward_meta, 'max_input_length', -1),
)
# Flash attention computation
fmha_out = flash_attn_unpadded(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
self.attn_softmax_scale,
causal=True,
training=False,
)[0]
return fmha_out
def forward_decode(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for the decode stage.
"""
metadata = self.attention_metadata
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
forward_meta, 'caches') else None
# Get speculative decoding parameters
speculate_decoder = self.speculative_method is not None
speculate_max_tokens = self.speculate_max_draft_token_num
# Write to the latent KV cache
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)
# Multi-head latent attention computation
fmha_out = multi_head_latent_attention(
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.
decoder_num_blocks, # PaddleNLP passes decoder_num_blocks_cpu here
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # out_shifts
None, # out_smooths
metadata._fuse_kernel_compute_dtype,
"none", # cache_quant_type
self.kv_lora_rank,
self.max_seq_len,
self.attn_softmax_scale,
0.0, # quant_max_bound
0.0, # quant_min_bound
0.0, # out_linear_in_scale
speculate_max_tokens,
True, # causal
speculate_decoder,
)
return fmha_out
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for mixed (prefill + decode) batches.
"""
metadata = self.attention_metadata
speculate_decoder = self.speculative_method is not None
speculate_max_tokens = self.speculate_max_draft_token_num
decode_stage = forward_meta.is_decode_batch
prefill_stage = not (forward_meta.is_decode_batch)
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index)
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
forward_meta, 'caches') else None
if prefill_stage:
# Write to the latent KV cache
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
"none",
self.max_seq_len,
)
# FA
fmha_out = flash_attn_unpadded(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
self.attn_softmax_scale,
causal=True,
training=False,
)[0]
return fmha_out
# Decode
if decode_stage:
# MLA write to the latent KV cache
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)
# Multi-head latent attention computation
fmha_out = multi_head_latent_attention(
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.padding_offset,
forward_meta.cum_offsets,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.
decoder_num_blocks, # PaddleNLP passes decoder_num_blocks_cpu here
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # out_shifts
None, # out_smooths
metadata._fuse_kernel_compute_dtype,
"none", # cache_quant_type
self.kv_lora_rank,
self.max_seq_len,
self.attn_softmax_scale,
0.0, # quant_max_bound
0.0, # quant_min_bound
0.0, # out_linear_in_scale
speculate_max_tokens,
True, # causal
speculate_decoder,
)
return fmha_out

View File

@@ -17,10 +17,16 @@
from .append_attention import append_attention
from .get_block_shape_and_split_kv_block import \
get_block_shape_and_split_kv_block
from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_signal_layerwise import init_signal_layerwise
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
from .pre_cache_len_concat import pre_cache_len_concat
__all__ = [
"get_block_shape_and_split_kv_block", "append_attention",
"open_shm_and_get_meta_signal", "init_signal_layerwise"
"get_block_shape_and_split_kv_block",
"append_attention",
"open_shm_and_get_meta_signal",
"init_signal_layerwise",
"gqa_rope_write_cache",
"pre_cache_len_concat",
]

View File

@@ -0,0 +1,66 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from fastdeploy.platforms import current_platform
def gqa_rope_write_cache(
qkv: paddle.Tensor,
key_cache: paddle.Tensor,
value_cache: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_k: paddle.Tensor,
rotary_embs: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
padding_offsets: paddle.Tensor,
cum_offsets: paddle.Tensor,
block_tables: paddle.Tensor,
kv_batch_ids: paddle.Tensor,
kv_tile_ids_per_batch: paddle.Tensor,
kv_num_blocks: paddle.Tensor,
cache_batch_ids: paddle.Tensor,
cache_tile_ids_per_batch: paddle.Tensor,
cache_num_blocks: paddle.Tensor,
cache_k_quant_scales: Optional[paddle.Tensor] = None,
cache_v_quant_scales: Optional[paddle.Tensor] = None,
cache_k_dequant_scales: Optional[paddle.Tensor] = None,
cache_v_dequant_scales: Optional[paddle.Tensor] = None,
cache_k_zp: Optional[paddle.Tensor] = None,
cache_v_zp: Optional[paddle.Tensor] = None,
kv_signal_data: Optional[paddle.Tensor] = None,
kv_token_num: int = 1,
max_seq_len: int = 0,
cache_quant_type: str = "none"):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
q, k, v, qkv_ = gqa_rope_write_cache(
qkv, key_cache, value_cache, cu_seqlens_q, cu_seqlens_k,
rotary_embs, seq_lens_this_time, seq_lens_encoder,
seq_lens_decoder, padding_offsets, cum_offsets, block_tables,
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks,
cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales,
cache_v_dequant_scales, cache_k_zp, cache_v_zp, kv_signal_data,
kv_token_num, max_seq_len, cache_quant_type)
return q, k, v, qkv_
else:
raise NotImplementedError()

View File

@@ -0,0 +1,36 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.platforms import current_platform
def pre_cache_len_concat(seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
max_dec_len: int = 0,
block_size: int = 64):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time,
max_dec_len, block_size)
return out
else:
raise NotImplementedError()

View File

@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
@@ -149,6 +149,8 @@ class XPUAttentionBackend(AttentionBackend):
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:

View File

@@ -41,16 +41,12 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
Create weights for linear layer on XPU
"""
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
layer.linear_weight_shape[0] //= 2
layer.weight_dtype = "int8"
linear_weight_scale_shape = [layer.embed_dim]
if hasattr(layer, "linear_weight_shape"):
if isinstance(layer.linear_weight_shape, list):
layer_weight_shape = layer.linear_weight_shape
linear_weight_scale_shape = layer_weight_shape[:1]
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
dtype="float32",

View File

@@ -14,10 +14,15 @@
# limitations under the License.
"""
from typing import Dict
import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from .utils import get_tensor
@@ -28,12 +33,12 @@ class VocabParallelEmbedding(nn.Layer):
def __init__(
self,
fd_config,
num_embeddings,
embedding_dim=768,
params_dtype="bfloat16",
fd_config: FDConfig,
num_embeddings: int,
embedding_dim: int = 768,
params_dtype: str = "bfloat16",
prefix="",
):
) -> None:
"""
Initialize the VocabParallelEmbedding layer for the model.
@@ -41,28 +46,28 @@ class VocabParallelEmbedding(nn.Layer):
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
num_embeddings : vocabulary size.
embedding_dim : size of hidden state.
params_dtype : data type of parameters.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
num_embeddings (int) : vocabulary size.
embedding_dim (int) : size of hidden state.
params_dtype (str) : data type of parameters.
prefix (str): The name of current layer. Defaults to "".
"""
super().__init__()
self.fd_config = fd_config
hcg = fleet.get_hybrid_communicate_group()
self.mp_rank = hcg.get_model_parallel_rank()
self.column_cut = fd_config.parallel_config.column_cut
self.world_size = hcg.get_model_parallel_world_size()
self.ring_id = hcg.get_model_parallel_group().id
self.use_rope = fd_config.model_config.use_rope
self.rope_head_dim = fd_config.model_config.rope_head_dim
self.use_ep = fd_config.parallel_config.use_ep
self.hidden_dropout_prob = fd_config.model_config.hidden_dropout_prob
self.initializer_range = fd_config.model_config.initializer_range
self.sequence_parallel = fd_config.parallel_config.sequence_parallel
self.max_position_embeddings = fd_config.model_config.max_position_embeddings
self.freeze_embedding = fd_config.model_config.freeze_embedding
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
self.mp_rank: int = hcg.get_model_parallel_rank()
self.column_cut = False
self.world_size: int = hcg.get_model_parallel_world_size()
self.ring_id: int = hcg.get_model_parallel_group().id
self.use_rope: bool = fd_config.model_config.use_rope
self.rope_head_dim: int = fd_config.model_config.rope_head_dim
self.use_ep: bool = fd_config.parallel_config.use_ep
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
self.initializer_range: float = fd_config.model_config.initializer_range
self.sequence_parallel: bool = fd_config.parallel_config.sequence_parallel
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
self.freeze_embedding: bool = fd_config.model_config.freeze_embedding
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
self.params_dtype: str = params_dtype
if self.use_ep:
self.word_embeddings = nn.Embedding(
@@ -109,7 +114,8 @@ class VocabParallelEmbedding(nn.Layer):
self.rope_head_dim_shape_tensor = paddle.ones((self.rope_head_dim),
dtype="int8")
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: Dict[str,
paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
@@ -125,7 +131,7 @@ class VocabParallelEmbedding(nn.Layer):
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
paddle.get_default_dtype()))
def forward(self, ids_remove_padding=None):
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
"""
Defines the forward computation of the layer.

View File

@@ -216,6 +216,14 @@ class ReplicatedLinear(LinearBase):
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant)
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.input_size,
self.output_size,
]
if fd_config.quant_config:
self.quant_method.create_weights(self)
self.init_weight()
@@ -259,7 +267,10 @@ class ColumnParallelLinear(LinearBase):
skip_quant=skip_quant)
self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.input_size = input_size
self.output_size = divide(output_size, self.nranks)
self.output_size = divide(
output_size,
self.nranks) # Split the output_size using TP inference.
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.input_size,
self.output_size,
@@ -282,7 +293,7 @@ class ColumnParallelLinear(LinearBase):
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_weight, split_axis=-1)
_set_var_distributed(self.linear_weight, split_axis=1)
self.linear_bias = None
if self.with_bias:
@@ -293,7 +304,7 @@ class ColumnParallelLinear(LinearBase):
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_bias, split_axis=-1)
_set_var_distributed(self.linear_bias, split_axis=1)
# smooth quant
self.linear_shift = None
@@ -318,7 +329,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
with_bias: bool = False,
add_bias: bool = False,
activation: str = "gelu",
use_fast_ffn: bool = False,
skip_quant: bool = False,
):
"""
@@ -333,13 +343,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
activation (str): Activation function to use. Defaults to "gelu".
use_fast_ffn (bool): Whether to use a faster FFN implementation.
Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.use_fast_ffn = use_fast_ffn
self.activation = activation
self.embed_dim = fd_config.model_config.hidden_size
self.hidden_size = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_degree
super().__init__(fd_config=fd_config,
@@ -374,23 +381,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"gate_proj")
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
paddle.get_default_dtype())
converted_bias_tensor = paddle.zeros(shape=list(
bias_tensor.shape),
dtype=bias_tensor.dtype)
if not self.use_fast_ffn:
converted_bias_tensor = paddle.concat(
[bias_tensor[::2], bias_tensor[1::2]], axis=0)
else:
converted_bias_tensor = bias_tensor
state_dict[self.bias_key] = converted_bias_tensor
if not self.use_fast_ffn:
converted_weight_tensor = paddle.concat(
[weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1)
else:
converted_weight_tensor = weight_tensor
state_dict[self.bias_key] = bias_tensor
state_dict[self.weight_key] = converted_weight_tensor
state_dict[self.weight_key] = weight_tensor
super().load_state_dict(state_dict)
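For context on the branch removed above: when use_fast_ffn was off, the loader regrouped alternating weight columns (and bias entries) into two contiguous blocks before handing them to the fused FFN. A small NumPy illustration of that now-removed transform, with made-up values:

import numpy as np

w = np.array([[0, 10, 1, 11, 2, 12]])                        # columns alternate between the two halves
converted = np.concatenate([w[:, ::2], w[:, 1::2]], axis=1)  # even columns first, then odd columns
print(converted)                                             # [[ 0  1  2 10 11 12]]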
@@ -413,12 +407,12 @@ class QKVParallelLinear(ColumnParallelLinear):
"""
self.num_heads = fd_config.model_config.num_attention_heads
self.kv_num_heads = fd_config.model_config.num_key_value_heads
self.embed_dim = fd_config.model_config.hidden_size
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
input_size = self.embed_dim
input_size = self.hidden_size
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
super().__init__(fd_config=fd_config,
prefix=prefix,
@@ -448,7 +442,7 @@ class QKVParallelLinear(ColumnParallelLinear):
weight_tensor = weight_tensor.reshape([
(self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) *
(self.head_dim),
self.embed_dim,
self.hidden_size,
])
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
@@ -513,6 +507,7 @@ class RowParallelLinear(LinearBase):
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
reduce_results: bool = True,
skip_quant: bool = False,
):
"""
@@ -538,10 +533,14 @@ class RowParallelLinear(LinearBase):
self.fd_config = fd_config
self.skip_quant = False
self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.embed_dim = fd_config.model_config.hidden_size
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
# Split input_size when using TP inference.
self.input_size = divide(input_size, self.nranks)
self.output_size = output_size
self.linear_weight_shape = [
self.input_size,
self.output_size,
@@ -551,6 +550,8 @@ class RowParallelLinear(LinearBase):
if fd_config.quant_config:
self.quant_method = fd_config.quant_config.get_quant_method(self)
self.quant_method.create_weights(self)
self.reduce_results = reduce_results
self.init_weight()
def init_weight(self):
@@ -570,7 +571,7 @@ class RowParallelLinear(LinearBase):
self.linear_bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
shape=[self.embed_dim],
shape=[self.hidden_size],
dtype=self._dtype,
is_bias=True,
)
@@ -589,7 +590,7 @@ class RowParallelLinear(LinearBase):
else:
out = paddle.matmul(x, self.linear_weight)
if self.nranks > 1:
if self.reduce_results and self.nranks > 1:
tensor_model_parallel_all_reduce(out)
return out
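The new reduce_results flag lets a caller skip the per-layer all-reduce and perform a single reduction after combining several partial outputs. That is safe because the reduction is a sum, so reducing once after adding partials equals adding the individually reduced results; a tiny single-process sketch (plain tensor addition stands in for the real collective, values are illustrative):

import paddle

# Partial outputs as they would exist on two tensor-parallel ranks.
partials_layer1 = [paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0])]
partials_layer2 = [paddle.to_tensor([0.5, 0.5]), paddle.to_tensor([1.5, 1.5])]

all_reduce_sum = lambda parts: parts[0] + parts[1]  # stand-in for tensor_model_parallel_all_reduce

two_collectives = all_reduce_sum(partials_layer1) + all_reduce_sum(partials_layer2)
one_collective = all_reduce_sum([a + b for a, b in zip(partials_layer1, partials_layer2)])
assert bool(paddle.allclose(two_collectives, one_collective))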

View File

@@ -14,10 +14,15 @@
# limitations under the License.
"""
from typing import Dict, Optional
import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from .utils import get_tensor
@@ -28,12 +33,12 @@ class ParallelLMHead(nn.Layer):
def __init__(
self,
fd_config,
num_embeddings,
embedding_dim,
prefix="",
with_bias=False,
):
fd_config: FDConfig,
num_embeddings: int,
embedding_dim: int,
prefix: str = "",
with_bias: bool = False,
) -> None:
"""
Parallelized LMhead.
@@ -43,21 +48,22 @@ class ParallelLMHead(nn.Layer):
num_attention_heads, and ffn_hidden_size.
num_embeddings (int): vocabulary size.
embedding_dim (int): size of hidden state.
prefix (str): full name of the layer in the state dict
prefix (str): The name of current layer. Defaults to "".
with_bias (bool): whether to have bias. Default: False.
"""
super(ParallelLMHead, self).__init__()
self.linear_weight_key = prefix + ".weight"
self.linear_weight_key: str = prefix + ".weight"
if with_bias:
self.linear_bias_key = prefix + ".bias"
self.linear_bias_key: Optional[str] = prefix + ".bias"
else:
self.linear_bias_key = None
self.use_ep = fd_config.parallel_config.use_ep
self.linear_bias_key: Optional[str] = None
self.use_ep: bool = fd_config.parallel_config.use_ep
self.column_cut = True
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
if self.use_ep:
self.weight = self.create_parameter(
@@ -92,7 +98,8 @@ class ParallelLMHead(nn.Layer):
fuse_matmul_bias=False,  # False yields a smaller numerical diff
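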
)
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: Dict[str,
paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
@@ -122,7 +129,7 @@ class ParallelLMHead(nn.Layer):
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)
def forward(self, input):
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
"""
Defines the forward computation of the layer.

View File

@@ -22,14 +22,35 @@ from paddleformers.utils.log import logger
import fastdeploy
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from ..utils import get_tensor, create_and_set_parameter
from fastdeploy.platforms import current_platform
from ..utils import create_and_set_parameter, get_tensor
from .fused_moe_backend_base import MoEMethodBase
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
moe_expert_reduce, noaux_tc)
# used for deepseek_v3
def get_moe_scores(gating_output: paddle.Tensor, n_group, topk_group, top_k,
routed_scaling_factor,
e_score_correction_bias) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
scores = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores
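noaux_tc is a compiled GPU op, so its exact semantics are not visible here; the sketch below only illustrates the general idea of group-limited routing it is used for (score each expert group, keep the best topk_group groups, then pick top_k experts from the surviving groups and scale the weights). It is a readability aid under those assumptions, not the kernel's implementation:

import paddle

def grouped_topk_sketch(scores, bias, n_group, topk_group, top_k, routed_scaling_factor):
    # scores: [num_tokens, num_experts] sigmoid gate outputs; bias: [num_experts].
    num_tokens, num_experts = scores.shape
    grouped = (scores + bias.unsqueeze(0)).reshape(
        [num_tokens, n_group, num_experts // n_group])
    group_scores = grouped.max(axis=-1)                                  # [tokens, n_group]
    _, kept_groups = paddle.topk(group_scores, k=topk_group, axis=-1)
    group_mask = paddle.zeros_like(group_scores)
    group_mask = paddle.put_along_axis(group_mask, kept_groups, 1.0, axis=-1)
    masked = (grouped * group_mask.unsqueeze(-1)).reshape([num_tokens, num_experts])
    _, topk_idx = paddle.topk(masked, k=top_k, axis=-1)
    # Routing weights come from the unbiased scores, rescaled by routed_scaling_factor.
    topk_weights = paddle.take_along_axis(scores, topk_idx, axis=-1) * routed_scaling_factor
    return topk_weights, topk_idx

scores = paddle.nn.functional.sigmoid(paddle.rand([4, 16]))   # 4 tokens, 16 experts, illustrative
bias = paddle.zeros([16])
w, idx = grouped_topk_sketch(scores, bias, n_group=4, topk_group=2,
                             top_k=4, routed_scaling_factor=1.0)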
class CutlassMoEMethod(MoEMethodBase):
"""
@@ -199,23 +220,47 @@ class CutlassMoEMethod(MoEMethodBase):
"""
Paddle Cutlass compute Fused MoE.
"""
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
) = moe_expert_dispatch(
x,
gate_out,
layer.gate_correction_bias,
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
topk_only_mode=False,
)
if layer.topk_method == "noaux_tc":
gate_out = get_moe_scores(gate_out, layer.n_group,
layer.topk_group, layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias)
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
) = moe_expert_dispatch(
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
topk_only_mode=True,
)
else:
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
) = moe_expert_dispatch(
x,
gate_out,
layer.gate_correction_bias,
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
topk_only_mode=False,
)
if self.moe_quant_type != "w4a8":
# only w4a8 need expert_idx_per_token
@@ -234,11 +279,11 @@ class CutlassMoEMethod(MoEMethodBase):
permute_indices_per_token,
topk_idx,
None,
norm_topk_prob=True,
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
routed_scaling_factor=1.0,
)
if layer.tp_size > 1:
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out

View File

@@ -195,8 +195,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
hidden_size = layer.hidden_size
num_experts = layer.num_experts
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,

View File

@@ -17,6 +17,7 @@
import paddle
from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
@@ -25,17 +26,24 @@ from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase
try:
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
from .triton_moe_kernels import fused_moe_kernel_paddle
except:
pass
class TritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Use Triton Group Gemm to compute Fused MoE.
"""
def __init__(self, quant_method=None):
def __init__(self, quant_config=None):
"""
Triton Group Gemm to compute Fused MoE.
"""
self.quant_method = quant_method
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
@@ -52,7 +60,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
assert layer.quant_method.quant_config.name() == "wint8"
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
assert ffn1_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
@@ -63,9 +75,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
if self.quant_config.name() == "wint8":
if algo == "wint8":
max_bound = 127
elif self.quant_config.name() == "wint4":
elif algo == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
@@ -111,15 +123,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
topk_weights, topk_ids = paddle.topk(scores,
k=top_k,
axis=-1,
sorted=False)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
False,
)
intermediate_cache1 = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
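For reference, the eager-mode routing that the fused moe_topk_select call above replaces (the removed lines, minus the gate_correction_bias handling) was equivalent to:

import paddle

gate_out = paddle.rand([8, 64])                          # [tokens, experts], illustrative sizes
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
topk_weights, topk_ids = paddle.topk(scores, k=4, axis=-1, sorted=False)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)   # apply_norm_weight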
@@ -139,14 +149,12 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
}
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from .triton_moe_kernels import fused_moe_kernel_paddle
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
max_num_tokens_padded = sorted_token_ids.shape[0]
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
x,
@@ -158,10 +166,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
moe_intermediate_size * 2,
hidden_size,
max_num_tokens_padded,
max_possible_num_post_padded,
token_num * top_k,
N=moe_intermediate_size * 2,
K=hidden_size,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
@@ -193,8 +201,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
intermediate_cache1)
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
@@ -205,10 +214,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
hidden_size,
moe_intermediate_size,
max_num_tokens_padded,
max_possible_num_post_padded,
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=intermediate_cache2.strides[0],
stride_ak=intermediate_cache2.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
@@ -324,7 +333,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
topk_weights, topk_ids = paddle.topk(scores,
@@ -352,13 +360,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
}
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
max_num_tokens_padded = sorted_token_ids.shape[0]
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
adamard_matrix = create_hadamard_matrix_map[hidden_size]
x = paddle.matmul(x.cast("float32"), adamard_matrix)
@@ -371,8 +379,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
permute_x = permute_x / quant_activation_scale
permute_x = permute_x.astype("float8_e4m3fn")
from .triton_moe_kernels import fused_moe_kernel_paddle
fused_moe_kernel_paddle[grid](
permute_x,
layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
@@ -383,10 +389,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
moe_intermediate_size * 2,
hidden_size,
max_num_tokens_padded,
max_possible_num_post_padded,
token_num * top_k,
N=moe_intermediate_size * 2,
K=hidden_size,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
@@ -426,8 +432,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
intermediate_cache2 = intermediate_cache2 / quant_activation_scale
intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn")
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
intermediate_cache2,
@@ -439,10 +446,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
hidden_size,
moe_intermediate_size,
max_num_tokens_padded,
max_possible_num_post_padded,
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=intermediate_cache2.strides[0],
stride_ak=intermediate_cache2.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],

View File

@@ -224,6 +224,7 @@ class TritonWint2FusedMoeMethod(Wint2MoeMethod):
)
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
fused_moe_out = moe_expert_reduce(
ffn_out,
topk_weights,

View File

@@ -30,10 +30,15 @@ class FusedMoE(nn.Layer):
def __init__(
self,
fd_config,
reduce_results: bool = True,
moe_intermediate_size: int = -1,
num_experts: int = -1,
expert_id_offset: int = 0,
top_k: int = -1,
topk_method: str = "",
topk_group: int = -1,
n_group: int = -1,
routed_scaling_factor: float = 1.0,
layer_idx: int = -1,
moe_tag: str = "",
weight_key_map: dict = {},
@@ -49,6 +54,7 @@ class FusedMoE(nn.Layer):
self.fd_config = fd_config
self.layer_idx = layer_idx
self.reduce_results = reduce_results
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
self.ep_size = fd_config.parallel_config.expert_parallel_degree
@@ -60,28 +66,33 @@ class FusedMoE(nn.Layer):
self.hidden_size = fd_config.model_config.hidden_size
self.moe_config = fd_config.moe_config
self.num_experts = num_experts
self.num_local_experts = self.num_experts // self.ep_size
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
self.top_k = top_k
self.hidden_size = self.hidden_size
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
self.weight_key_map = weight_key_map
self.use_method = envs.FD_MOE_BACKEND.lower()
self.gate_correction_bias = None
self.moe_tag = moe_tag
if self.ep_size > 1:
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
self.expert_id_offset = expert_id_offset
if fd_config.quant_config:
self.quant_method = fd_config.quant_config.get_quant_method(self)
# used for deepseek_v3
self.topk_method = topk_method
self.topk_group = topk_group
self.n_group = n_group
self.routed_scaling_factor = routed_scaling_factor
moe_quant_config = fd_config.quant_config
self.moe_quant_type = None
if moe_quant_config:
self.quant_method = moe_quant_config.get_quant_method(self)
self.moe_quant_type = moe_quant_config.name()
else:
# For now the unquantized (w_fp16, a_fp16) path cannot obtain a quant method from quant_config; this will be optimized in the future.
from .fused_moe_cutlass_backend import CutlassMoEMethod
@@ -90,12 +101,78 @@ class FusedMoE(nn.Layer):
if self.ep_size > 1:
self.quant_method.init_ep(self)
if fd_config.load_config.dynamic_load_weight:
# It's for RL to build model
self.init_moe_weights()
logger.info(
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \
{top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \
ep_size={self.ep_size}, \
tp_size={self.tp_size}.")
def init_moe_weights(self):
"""
Initialize the weight shapes and parameters for the MoE layer.
Combines weight shape initialization and parameter creation into a single function.
"""
# Initialize weight shapes
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
gate_weight_shape = [self.hidden_size, self.num_experts]
gate_correction_bias_shape = [1, self.num_experts]
self.gate_weight = self.create_parameter(
shape=gate_weight_shape,
dtype="float32",
)
if self.moe_config.moe_use_aux_free:
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_shape,
dtype="float32",
)
ffn1_output_dim = self.moe_intermediate_size * 2
if self.moe_quant_type in ["fp8", "wint8"]:
ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size]
ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
else:
ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim]
ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
# Create parameters
if self.moe_quant_type == "fp8":
#(TODO:gaoziyuan)
pass
elif self.moe_quant_type == "wint8":
self.weight_dtype = "int8"
self.init_weight_only_scale()
# FFN1 parameters
self.moe_ffn1_weight = self.create_parameter(
shape=ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
# FFN2 parameters
self.moe_ffn2_weight = self.create_parameter(
shape=ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
def init_weight_only_scale(self):
"""
Initialize the weight scale.
"""
self.moe_ffn1_weight_scale = self.create_parameter(
shape=[self.num_local_experts, self.moe_intermediate_size * 2],
dtype=self._dtype,
)
self.moe_ffn2_weight_scale = self.create_parameter(
shape=[self.num_local_experts, self.hidden_size],
dtype=self._dtype,
)
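A quick shape sanity check for the wint8 branch above, using illustrative sizes (8 local experts, hidden_size 4096, moe_intermediate_size 1536 after the TP split):

num_local_experts, hidden_size, moe_intermediate_size = 8, 4096, 1536
ffn1_output_dim = moe_intermediate_size * 2

# wint8/fp8 keep the expert weights pre-transposed as [E, out_dim, in_dim]:
ffn1_weight_shape = [num_local_experts, ffn1_output_dim, hidden_size]         # [8, 3072, 4096]
ffn2_weight_shape = [num_local_experts, hidden_size, moe_intermediate_size]   # [8, 4096, 1536]

# One dequantization scale per (expert, output channel):
ffn1_scale_shape = [num_local_experts, ffn1_output_dim]                       # [8, 3072]
ffn2_scale_shape = [num_local_experts, hidden_size]                           # [8, 4096]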
def load_experts_weight(self, state_dict: dict,
ffn1_expert_weight_key: str,
ffn2_expert_weight_key: str):

View File

@@ -16,9 +16,10 @@
import triton
import triton.language as tl
from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import paddle_use_triton_v2
@triton.jit
@paddle_use_triton_v2()
def fused_moe_kernel_paddle(
a_ptr,
b_ptr,
@@ -31,22 +32,22 @@ def fused_moe_kernel_paddle(
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
num_tokens_post_padded,
max_possible_num_post_padded,
num_valid_tokens,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
N: tl.constexpr,
K: tl.constexpr,
stride_am: tl.constexpr,
stride_ak: tl.constexpr,
stride_be: tl.constexpr,
stride_bk: tl.constexpr,
stride_bn: tl.constexpr,
stride_cm: tl.constexpr,
stride_cn: tl.constexpr,
stride_asm: tl.constexpr,
stride_ask: tl.constexpr,
stride_bse: tl.constexpr,
stride_bsk: tl.constexpr,
stride_bsn: tl.constexpr,
# Block size for block-wise fp8 quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
@@ -87,7 +88,7 @@ def fused_moe_kernel_paddle(
multiplication across different blocks processed by the same expert.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M)
num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
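The kernel maps each program id to a (token-block, N-block) tile, so the launch grid in the callers above is sized as ceil_div(max_possible_num_post_padded, BLOCK_SIZE_M) * ceil_div(N, BLOCK_SIZE_N). A quick numeric sketch with illustrative sizes:

def ceil_div(a, b):   # local stand-in for fastdeploy.utils.ceil_div
    return (a + b - 1) // b

max_possible_num_post_padded = 4096      # sorted_token_ids.shape[0], illustrative
N, BLOCK_SIZE_M, BLOCK_SIZE_N = 14336, 64, 64
num_pid_m = ceil_div(max_possible_num_post_padded, BLOCK_SIZE_M)   # 64 token blocks
num_pid_n = ceil_div(N, BLOCK_SIZE_N)                              # 224 column blocks
print(num_pid_m * num_pid_n)                                       # 14336 program instances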

View File

@@ -0,0 +1,133 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
from paddle.distributed import fleet
from .utils import get_tensor
class ParallelEHProjection(nn.Layer):
"""
"Parallelized Embedding Hidden States Projection.
"""
def __init__(
self,
fd_config,
num_embeddings,
embedding_dim,
prefix="",
with_bias=False,
):
"""
Parallelized Embedding Hidden States Projection.
Args:
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
num_embeddings (int): vocabulary size.
embedding_dim (int): size of hidden state.
prefix (str): full name of the layer in the state dict
"""
super(ParallelEHProjection, self).__init__()
self.linear_weight_key = prefix + ".weight"
if with_bias:
self.linear_bias_key = prefix + ".bias"
else:
self.linear_bias_key = None
self.use_ep = fd_config.parallel_config.use_ep
self.column_cut = True
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
if self.use_ep:
self.weight = self.create_parameter(
shape=[embedding_dim, num_embeddings],
dtype=paddle.get_default_dtype(),
is_bias=False,
)
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False,  # False yields a smaller numerical diff
)
else:
self.out_linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False,  # False yields a smaller numerical diff
)
def load_state_dict(self, state_dict):
"""
Load the checkpoint state dictionary into the layer.
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
paddle.get_default_dtype()))
else:
weight_tensor = get_tensor(
state_dict.pop(self.linear_weight_key)).astype(
paddle.get_default_dtype())
if self.out_linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.out_linear.weight.set_value(weight_tensor)
if self.linear_bias_key is not None:
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)
def forward(self, input):
"""
Defines the forward computation of the layer.
Args:
input (Tensor): The input tensor to the layer.
Returns:
Tensor: The output tensor after processing through the layer.
"""
logits = input
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.out_linear(logits)
return logits
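The shape check in load_state_dict above lets the projection accept a checkpoint whose weight is stored transposed relative to the layer's parameter; a small NumPy illustration of that fallback (shapes are illustrative):

import numpy as np

embedding_dim, num_embeddings = 4, 6
expected_shape = (embedding_dim, num_embeddings)

ckpt_weight = np.zeros((num_embeddings, embedding_dim))   # stored transposed in this example
if ckpt_weight.shape != expected_shape:
    ckpt_weight = ckpt_weight.transpose([1, 0])
assert ckpt_weight.shape == expected_shape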

View File

@@ -14,10 +14,15 @@
# limitations under the License.
"""
from typing import Callable, Dict, Optional
import numpy as np
import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
from fastdeploy.config import FDConfig
from .utils import get_tensor
@@ -28,16 +33,16 @@ class RMSNorm(nn.Layer):
def __init__(
self,
fd_config,
hidden_size,
eps=1e-5,
prefix="",
linear_bias=None,
quant_scale=None,
begin_norm_axis=1,
):
fd_config: FDConfig,
hidden_size: int,
eps: float = 1e-5,
prefix: str = "",
linear_bias: paddle.Tensor = None,
quant_scale: float = None,
begin_norm_axis: int = 1,
) -> None:
"""
Initializes the normalization layer.
Initializes the RMSNormalization layer.
Args:
fd_config (FDConfig): Arguments related to inference, containing
@@ -45,33 +50,33 @@ class RMSNorm(nn.Layer):
num_attention_heads, and ffn_hidden_size.
hidden_size (int) : size of hidden state.
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
weight_key (str): Key name of weight in the pdparams state dict. Defaults to None, means no weight.
bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
prefix(str,optional):The name of current layer. Defaults to "".
linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1.
Raises:
NotImplementedError: If the specified norm_type is not supported.
"""
super().__init__()
self.fd_config = fd_config
self.prefix = prefix
self.hidden_size = hidden_size
self.prefix: str = prefix
self.hidden_size: int = hidden_size
if len(prefix) == 0:
self.weight_key = None
self.weight_key: Optional[str] = None
else:
self.weight_key = f"{prefix}.weight"
self.with_weight = self.weight_key is not None
self.eps = eps
self.norm_func = fused_rms_norm
self.linear_bias = linear_bias
self.quant_scale = quant_scale
self._dtype = self._helper.get_default_dtype()
self._norm_weight_dtype = self._dtype
self.begin_norm_axis = begin_norm_axis
self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
self.begin_norm_axis = begin_norm_axis
self.weight_key: Optional[str] = f"{prefix}.weight"
self.with_weight: bool = self.weight_key is not None
self.eps: float = eps
self.norm_func: Callable = fused_rms_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self.quant_scale: Optional[float] = quant_scale
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = self._dtype
self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
self.begin_norm_axis: int = begin_norm_axis
self.init_weight()
@@ -88,7 +93,8 @@ class RMSNorm(nn.Layer):
dtype=self._norm_weight_dtype,
)
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: Dict[str,
paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
@@ -102,7 +108,10 @@ class RMSNorm(nn.Layer):
self._norm_weight_dtype)
self.ln_weight.set_value(weight_tensor)
def forward(self, x, residual_input=None):
def forward(
self,
x,
residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
@@ -140,18 +149,18 @@ class RMSNorm(nn.Layer):
class LayerNorm(nn.Layer):
"""
Normalization layer.
Initializes the LayerNormalization layer
"""
def __init__(
self,
fd_config,
hidden_size,
eps=1e-5,
fd_config: FDConfig,
hidden_size: int,
eps: float = 1e-5,
prefix="",
linear_bias=None,
quant_scale=None,
with_bias=False,
linear_bias: paddle.Tensor = None,
quant_scale: float = None,
with_bias: bool = False,
):
"""
Initializes the normalization layer.
@@ -160,35 +169,37 @@ class LayerNorm(nn.Layer):
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
hidden_size (int) : size of hidden state.
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
with_bias (bool):Whether to include bias or not. Defaults to False.
Raises:
NotImplementedError: If the specified norm_type is not supported.
"""
super().__init__()
self.fd_config = fd_config
self.prefix = prefix
self.hidden_size = hidden_size
self.prefix: str = prefix
self.hidden_size: int = hidden_size
if len(prefix) == 0:
self.weight_key = None
self.weight_key: Optional[str] = None
else:
self.weight_key = f"{prefix}.weight"
self.with_weight = self.weight_key is not None
self.bias_key = f"{prefix}.bias"
self.with_bias = with_bias
self.eps = eps
self.weight_key: Optional[str] = f"{prefix}.weight"
self.with_weight: bool = self.weight_key is not None
self.bias_key: str = f"{prefix}.bias"
self.with_bias: bool = with_bias
self.eps: float = eps
self.quant_scale: float = quant_scale
self.norm_func: Callable = fused_layer_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = "float32"
self.norm_func = fused_layer_norm
self.linear_bias = linear_bias
self._dtype = self._helper.get_default_dtype()
self._norm_weight_dtype = "float32"
self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
self.init_weight()
@@ -212,7 +223,8 @@ class LayerNorm(nn.Layer):
dtype=self._norm_weight_dtype,
)
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: Dict[str,
paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.
@@ -233,7 +245,10 @@ class LayerNorm(nn.Layer):
self._norm_weight_dtype)
self.ln_bias.set_value(bias_tensor)
def forward(self, x, residual_input=None):
def forward(
self,
x,
residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
@@ -259,7 +274,7 @@ class LayerNorm(nn.Layer):
begin_norm_axis=1,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,

View File

@@ -15,8 +15,9 @@
"""
from typing import Optional
from ..attention import Attention
from ..moe import FusedMoE
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase

View File

@@ -132,18 +132,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_dtype = "int8"
linear_weight_scale_shape = [layer.embed_dim]
if hasattr(layer, "linear_weight_shape"):
if isinstance(layer.linear_weight_shape, list):
layer_weight_shape = layer.linear_weight_shape
linear_weight_scale_shape = layer_weight_shape[:1]
if self.quant_config.name() == "wint4":
linear_weight_scale_shape[0] *= 2
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
dtype=layer._dtype,
@@ -195,6 +191,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
weight_scale.astype(paddle.get_default_dtype()))
def process_loaded_weights(self, layer, weight) -> None:
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
weight,
algo=self.quant_config.algo,

View File

@@ -14,13 +14,18 @@
# limitations under the License.
"""
from typing import Optional
import math
from typing import Optional, Tuple
import paddle
import paddle.nn as nn
from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding
from .utils import CpuGuard
@@ -99,20 +104,164 @@ class QwenRotaryEmbedding:
return rot_emb
def yarn_get_mscale(scale=1, mscale=1):
"""
"""
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
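A quick numeric check of the formula above: for a 4x context extension with mscale=1 the magnitude scale is 0.1 * ln(4) + 1 ≈ 1.139, and it stays exactly 1.0 whenever scale <= 1:

import math

scale, mscale = 4.0, 1.0
print(0.1 * mscale * math.log(scale) + 1.0)   # ≈ 1.1386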
def yarn_find_correction_dim(num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
"""
"""
return (dim * math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 *
math.log(base))
def yarn_find_correction_range(low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
"""
"""
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def yarn_linear_ramp_mask(min, max, dim):
"""
"""
if min == max:
max += 0.001 # Prevent singularity
linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max -
min)
ramp_func = paddle.clip(linear_func, 0, 1)
return ramp_func
class DeepseekScalingRotaryEmbedding(nn.Layer):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
Args:
rotary_dim(int): Dimension of rotary embeddings (head dimension)
max_position_embeddings(int): Original training context length
base(float): Base value used to compute the inverse frequencies.
scaling_factor(float): Context extension scaling ratio (target_len / original_len)
extrapolation_factor(float): Weight for extrapolated frequencies (default=1)
attn_factor(float): Attention magnitude scaling factor (default=1)
beta_fast(int): High-frequency correction cutoff (default=32)
beta_slow(int): Low-frequency correction cutoff (default=1)
mscale(float): Primary magnitude scaling factor (default=1)
mscale_all_dim(float): Alternate magnitude scaling factor (default=0)
"""
def __init__(
self,
rotary_dim: int,
max_position_embeddings: int,
base: int,
scaling_factor: float,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
super().__init__()
self._dtype = paddle.get_default_dtype()
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
cache = self._compute_cos_sin_cache()
self.cos_sin_cache: paddle.Tensor
self.register_buffer("cos_sin_cache", cache, persistable=True)
def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
pos_freqs = self.base**(
paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> paddle.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = paddle.arange(self.max_position_embeddings * self.scaling_factor,
dtype=paddle.float32)
freqs = paddle.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = paddle.concat((cos, sin), axis=-1)
return cache.cast(self._dtype)
def forward(
self,
position_ids: paddle.Tensor,
query: paddle.Tensor,
key: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
"""
# In-place operations that update the query and key tensors.
fused_rotary_position_encoding(query, key, position_ids,
self.cos_sin_cache, self.rotary_dim,
False)
return query, key
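A hedged construction example for the class above (CPU-friendly: building the layer only precomputes the cos/sin cache, while forward() additionally needs the CUDA fused_rotary_position_encoding op). The sizes are illustrative:

import paddle

rope = DeepseekScalingRotaryEmbedding(
    rotary_dim=64,
    max_position_embeddings=4096,
    base=10000,
    scaling_factor=4.0,        # original 4k context stretched to 16k
    mscale=1.0,
    mscale_all_dim=1.0,
)
print(rope.cos_sin_cache.shape)   # [16384, 64]: cos and sin halves concatenated on the last axis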
def get_rope_impl(
rotary_dim: int,
base: 10000.0,
position_ids,
position_ids: paddle.Tensor,
model_config: Optional[ModelConfig] = None,
partial_rotary_factor=1,
):
) -> paddle.Tensor:
"""
The real implementation of get_rope
"""
architecture = model_config.architectures[0]
if model_config is not None and model_config is None or architecture.startswith(
"Qwen"):
if model_config is None or architecture.startswith("Qwen"):
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base,
partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
@@ -126,10 +275,10 @@ def get_rope_impl(
def get_rope_xpu(
rotary_dim: int,
base: 10000.0,
position_ids,
model_config: ModelConfig,
position_ids: paddle.Tensor,
model_config: Optional[ModelConfig] = None,
partial_rotary_factor=1,
):
) -> paddle.Tensor:
"""
In XPU, cos and sin compute must be done on cpu
"""
@@ -143,12 +292,27 @@ def get_rope_xpu(
def get_rope(
rotary_dim: int,
base: 10000.0,
position_ids,
model_config: ModelConfig,
partial_rotary_factor=1,
):
position_ids: paddle.Tensor,
model_config: Optional[ModelConfig] = None,
partial_rotary_factor: int = 1,
) -> paddle.Tensor:
"""
The warpper of get_rope
Pre-calculate rotary position embedding for position_ids.
Args:
rotary_dim (int):
Dimension of rotary embeddings (head dimension)
base (float, optional):
Base value used to compute the inverse frequencies.
Default: 10000.0.
position_ids (paddle.Tensor):
Tensor containing position indices of input tokens.
model_config (Optional[ModelConfig]):
Model configuration object containing architecture information.
If provided, determines RoPE implementation based on model architecture.
partial_rotary_factor (int, optional):
Factor controlling partial rotary application.
Default: 1 (apply to all dimensions).
"""
if current_platform.is_xpu():
return get_rope_xpu(rotary_dim, base, position_ids, model_config,
@@ -255,7 +419,24 @@ def get_rope_3d(
paritial_rotary_factor: 1,
max_position: 131072,
freq_allocation: 2,
):
) -> paddle.Tensor:
"""
Pre-calculate rotary position embedding for position_ids.
Args:
rotary_dim (int):
Dimension of rotary embeddings (head dimension)
base (float, optional):
Base value used to compute the inverse frequencies.
Default: 10000.0.
position_ids (paddle.Tensor):
Tensor containing position indices of input tokens.
partial_rotary_factor (int, optional):
Factor controlling partial rotary application.
Default: 1 (apply to all dimensions).
max_position: Maximum position index to precompute.
freq_allocation: Number of rotary dimensions allocated to temporal axis
"""
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base,
paritial_rotary_factor,
max_position,

View File

@@ -377,4 +377,4 @@ def create_and_set_parameter(layer: nn.Layer, name: str,
dtype=tensor.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
))
getattr(layer, name).set_value(tensor)
getattr(layer, name).set_value(tensor)

View File

@@ -0,0 +1,289 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import os
import paddle
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.model_utils import load_tp_checkpoint
from safetensors import safe_open
from tqdm import tqdm
from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.models.tp_utils import \
check_tensor_parallel_prerequisites
from fastdeploy.platforms import current_platform
def load_ep_checkpoint(model_path: str,
config: ModelConfig,
return_numpy: bool = False):
"""
load ep checkpoint
"""
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
weight_list = json.load(f)["weight_map"]
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
num_local_ffn_keys = []
for i in range(config.moe_layer_start_index, config.num_layers):
for j in range(
config.num_experts_start_offset,
config.num_experts_start_offset + config.num_experts_per_rank,
):
ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
ffn2_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
ffn2_quant_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight")
ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
ffn2_scale_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale")
num_local_ffn_keys.append(ffn1_key)
num_local_ffn_keys.append(ffn2_key)
num_local_ffn_keys.append(ffn1_quant_key)
num_local_ffn_keys.append(ffn2_quant_key)
num_local_ffn_keys.append(ffn1_scale_key)
num_local_ffn_keys.append(ffn2_scale_key)
for k in num_local_ffn_keys:
if k in weight_list:
filtered_map[k] = weight_list[k]
state_dict = {}
# Get all safetensor file paths that need to be opened
safetensor_paths = set(filtered_map.values())
# Open each safetensor file sequentially with progress bar
for safetensor_path in tqdm(safetensor_paths,
desc="Loading safetensor files",
unit="file"):
with safe_open(os.path.join(model_path, safetensor_path),
framework="np",
device="cpu") as f:
# Check if this file contains keys from filtered_map
for k in filtered_map:
if filtered_map[k] == safetensor_path and k in f.keys():
weight = f.get_tensor(k)
if not return_numpy:
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(
paddle.framework._current_expected_place(), False)
state_dict[k] = weight
return state_dict
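The expert-key filtering above only touches this rank's slice of the experts; the slice boundaries are simple offset arithmetic. For example, with 64 experts split over 8 EP ranks (values illustrative):

num_experts, ep_size, ep_rank = 64, 8, 3
num_experts_per_rank = num_experts // ep_size
num_experts_start_offset = ep_rank * num_experts_per_rank
local_expert_ids = range(num_experts_start_offset,
                         num_experts_start_offset + num_experts_per_rank)
print(list(local_expert_ids))   # rank 3 loads experts 24..31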
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
"""
safetensors_weights_iterator
"""
for st_file in tqdm(
safe_tensor_list,
desc="Loading safetensors checkpoint shards",
):
with safe_open(st_file, framework="np") as f:
for name in f.keys():
param = f.get_tensor(name)
yield name, param
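A short usage sketch for the iterator above: stream tensors shard by shard instead of materializing the whole checkpoint at once ("model_dir" and the glob pattern are illustrative):

import glob
import os

shard_files = sorted(glob.glob(os.path.join("model_dir", "*.safetensors")))
state_dict = {}
for name, array in safetensors_weights_iterator(shard_files):
    state_dict[name] = array   # numpy arrays, because the files are opened with framework="np"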
def fastsafetensors_weights_iterator(safetensor_list: list[str], ):
"""
Return an iterator over tensors on GPU from a given safetensor_list.
"""
world_size = dist.get_world_size()
if world_size > 1:
pg = dist.get_group()
device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
else:
pg = SingleGroup()
device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda(
) else "cpu"
safetensor_files_sub_lists = [
safetensor_list[i:i + world_size]
for i in range(0, len(safetensor_list), world_size)
]
for st_file in tqdm(
safetensor_files_sub_lists,
desc="Loading fastsafetensors checkpoint shards",
):
loader = SafeTensorsFileLoader(pg,
device,
nogds=True,
debug_log=False,
framework="paddle")
rank_file_map = {i: [f] for i, f in enumerate(st_file)}
loader.add_filenames(rank_file_map)
try:
fb = loader.copy_files_to_device()
try:
keys = list(fb.key_to_rank_lidx.keys())
for k in keys:
t = fb.get_tensor(k)
yield k, t
finally:
fb.close()
finally:
loader.close()
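The sub-list chunking above assigns one shard file per rank per iteration; a tiny illustration of the grouping for five shards on two ranks (file names are made up):

safetensor_list = ["s-00001.safetensors", "s-00002.safetensors",
                   "s-00003.safetensors", "s-00004.safetensors",
                   "s-00005.safetensors"]
world_size = 2
sub_lists = [safetensor_list[i:i + world_size]
             for i in range(0, len(safetensor_list), world_size)]
print(sub_lists)   # [[s-00001, s-00002], [s-00003, s-00004], [s-00005]]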
def load_pre_sharded_checkpoint(model_path: str,
local_rank: int,
use_fastsafetensor: bool = False):
"""
load_pre_sharded_checkpoint
"""
state_dict = {}
_, safetensor_files = get_all_safetensors(
os.path.join(model_path, f"rank{local_rank}"))
weights_iterator = safetensors_weights_iterator(safetensor_files)
for name, weight in weights_iterator:
state_dict[name] = weight
return state_dict
def get_all_safetensors(model_path: str):
"""
get_all_safetensors
"""
safe_model_path = os.path.join(model_path, "model.safetensors")
if os.path.exists(safe_model_path):
safetensor_list = [safe_model_path]
with safe_open(safe_model_path, framework="np", device="cpu") as f:
key_name_list = f.keys()
return key_name_list, safetensor_list
else:
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
os.path.join(model_path, weight_map[weight_name]))
key_name_list = list(set(weight_map.keys()))
safetensor_list = list(weight_files_in_index)
safetensor_list.sort()
return key_name_list, safetensor_list
def load_tp_checkpoint_v1(
model_path: str,
cls: PretrainedModel,
fd_config: FDConfig,
use_fastsafetensor: bool = True,
):
"""
load_tp_checkpoint_v1
"""
safetensor_keys, safetensor_files = get_all_safetensors(model_path)
if use_fastsafetensor:
weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
else:
weights_iterator = safetensors_weights_iterator(safetensor_files)
tensor_parallel_filtered_map = {}
check_tensor_parallel_prerequisites(
fd_config,
cls,
tensor_parallel_filtered_map,
safetensor_keys,
)
need_tp = True if tensor_parallel_filtered_map else False
state_dict = {}
for key, weight in weights_iterator:
paddle.device.synchronize()
if need_tp and key in tensor_parallel_filtered_map:
action = tensor_parallel_filtered_map.pop(key)
tensor = action(weight).clone()
else:
tensor = weight.clone()
state_dict[key] = tensor
weight.value().get_tensor()._clear()
return state_dict
def deal_state_dict(state_dict):
"""deal_state_dict"""
device = paddle.CUDAPinnedPlace()
for name, src in state_dict.items():
if src._is_initialized() and not isinstance(src.place,
paddle.CUDAPinnedPlace):
dst = src._copy_to(device, True)
dst_tensor = dst.value().get_tensor()
src_tensor = src.value().get_tensor()
src_tensor._clear()
src_tensor._share_data_with(dst_tensor)
def load_composite_checkpoint(
model_path: str,
cls: PretrainedModel,
fd_config: FDConfig,
return_numpy=True,
):
"""
# This method supports loading model weights under three parallelism strategies:
# 1. Expert Parallel (EP)
# 2. Tensor Parallel (TP)
# 3. Pre-sharded (pre-split)
"""
if fd_config.parallel_config.use_ep:
state_dict = load_ep_checkpoint(model_path,
fd_config.model_config,
return_numpy=True)
else:
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank")
and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_degree != len(
rank_dirs):
raise ValueError(
f"Your model only supports loading with tp{len(rank_dirs)}"
)
state_dict = load_pre_sharded_checkpoint(
model_path,
fd_config.parallel_config.tensor_parallel_rank,
use_fastsafetensor=False,
)
else:
if fd_config.load_config.use_fastsafetensor and (
current_platform.available()
and current_platform.is_cuda()):
state_dict = load_tp_checkpoint_v1(model_path,
cls,
fd_config,
use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
state_dict = load_tp_checkpoint(model_path,
cls,
fd_config.model_config,
return_numpy=return_numpy)
if not state_dict:
raise ValueError("weight not found in state_dict !")
return state_dict

View File

@@ -20,6 +20,10 @@ import paddle
from paddle import nn
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
from fastdeploy.model_executor.load_weight_utils import \
load_composite_checkpoint
from fastdeploy.model_executor.models.deepseek_v3 import \
DeepSeekV3PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_moe import \
Ernie4_5_PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_mtp import \
@@ -28,7 +32,7 @@ from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
from fastdeploy.model_executor.models.utils import load_checkpoint
from fastdeploy.platforms import current_platform
MODEL_CLASSES = {
"Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel,
@@ -36,7 +40,8 @@ MODEL_CLASSES = {
"Qwen2ForCausalLM": Qwen2PretrainedModel,
"Qwen3ForCausalLM": Qwen3PretrainedModel,
"Qwen3MoeForCausalLM": Qwen3MoePretrainedModel,
"Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel
"Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel,
"DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel,
}
@@ -73,23 +78,43 @@ class DefaultModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
pass
def clean_memory_fragments(self, state_dict: dict) -> None:
"""clean_memory_fragments"""
if current_platform.is_cuda():
if state_dict:
for k, v in state_dict.items():
if isinstance(v, paddle.Tensor):
v.value().get_tensor()._clear()
paddle.device.cuda.empty_cache()
paddle.device.synchronize()
def load_model(self, fd_config: FDConfig) -> nn.Layer:
context = paddle.LazyGuard()
architectures = fd_config.model_config.architectures[0]
# TODO(gongshaotian): currently only safetensors checkpoints are supported
if fd_config.load_config.dynamic_load_weight:
# register rl model
import fastdeploy.rl
architectures = architectures + "RL"
model_class = MODEL_CLASSES[architectures]
state_dict = load_checkpoint(
fd_config.parallel_config.model_name_or_path,
model_class,
fd_config.model_config,
return_numpy=True)
with context:
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(fd_config)
model.eval()
model.set_state_dict(state_dict)
# RL models do not need set_state_dict
if fd_config.load_config.dynamic_load_weight:
return model
# TODO(gongshaotian): currently only safetensors checkpoints are supported
model_class = MODEL_CLASSES[architectures]
state_dict = load_composite_checkpoint(
fd_config.parallel_config.model_name_or_path,
model_class,
fd_config,
return_numpy=True,
)
model.set_state_dict(state_dict)
self.clean_memory_fragments(state_dict)
return model
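For orientation, a hedged sketch of driving the loader end to end; the loader constructor arguments (if any) and fd_config are assumptions, only the method names appear in this diff:

loader = DefaultModelLoader()          # construction details are not shown in this diff
model = loader.load_model(fd_config)   # returns an eval-mode nn.Layer with weights applied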

View File

@@ -20,15 +20,6 @@ from pathlib import Path
from .model_base import ModelForCasualLM, ModelRegistry
inference_runner_supported_models = [
"Ernie4_5_MoeForCausalLM",
"Ernie4_5_MTPForCausalLM",
"Qwen2ForCausalLM",
"Qwen3MoeForCausalLM",
"Ernie4_5_ForCausalLM",
"Qwen3ForCausalLM",
]
def _find_py_files(root_dir):
root_path = Path(root_dir)
@@ -44,14 +35,14 @@ def _find_py_files(root_dir):
return py_files
def auto_models_registry():
def auto_models_registry(dir_path,
register_path="fastdeploy.model_executor.models"):
"""
auto registry all models in this folder
"""
for module_file in _find_py_files(os.path.dirname(__file__)):
for module_file in _find_py_files(dir_path):
try:
module = importlib.import_module(
f'fastdeploy.model_executor.models.{module_file}')
module = importlib.import_module(f'{register_path}.{module_file}')
for attr_name in dir(module):
attr = getattr(module, attr_name)
if inspect.isclass(attr) and issubclass(
@@ -62,4 +53,4 @@ def auto_models_registry():
raise ImportError(f"{module_file=} import error")
auto_models_registry()
auto_models_registry(os.path.dirname(__file__))
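The added dir_path/register_path parameters make the registry hook reusable from outside this package; a hedged example with a hypothetical plugin package:

# Register model classes from an out-of-tree package that follows the same
# ModelForCasualLM / ModelRegistry conventions; "my_models" is illustrative.
auto_models_registry("/opt/my_models", register_path="my_models")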

View File

@@ -0,0 +1,762 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import math
from functools import partial
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear,
ReplicatedLinear, RowParallelLinear)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.rotary_embedding import \
DeepseekScalingRotaryEmbedding
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.platforms import current_platform
from fastdeploy.worker.forward_meta import ForwardMeta
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import \
get_position_ids_and_mask_encoder_batch
class DeepSeekV3MLP(nn.Layer):
"""
DeepSeekV3MLP, for Dense FFN and Shared Experts Layer.
"""
def __init__(
self,
fd_config: FDConfig,
intermediate_size: int,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
output_size=intermediate_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
)
self.down_proj = RowParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.down_proj",
input_size=intermediate_size,
output_size=fd_config.model_config.hidden_size,
with_bias=False,
reduce_results=reduce_results,
)
self.act_fn = SiluAndMul(
fd_config=fd_config,
bias=None,
act_method=fd_config.model_config.hidden_act,
)
def load_state_dict(self, state_dict):
"""
"""
self.gate_up_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, x):
"""
"""
gate_up_out = self.gate_up_proj(x)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
return down_out
class DeepSeekV3MoE(nn.Layer):
"""
DeepSeekV3MoE, for MoE Layer.
"""
def __init__(self, fd_config: FDConfig, layer_id: int,
prefix: str) -> None:
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.gate.e_score_correction_bias",
"ffn1_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"ffn2_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
}
self.fused_moe = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.deepseekv3.
moe_intermediate_size,
num_experts=fd_config.model_config.deepseekv3.n_routed_experts,
top_k=fd_config.model_config.deepseekv3.num_experts_per_tok,
topk_method=fd_config.model_config.deepseekv3.topk_method,
topk_group=fd_config.model_config.deepseekv3.topk_group,
n_group=fd_config.model_config.deepseekv3.n_group,
routed_scaling_factor=fd_config.model_config.deepseekv3.
routed_scaling_factor,
layer_idx=layer_id,
weight_key_map=weight_key_map,
)
self.num_shared_experts = fd_config.model_config.deepseekv3.n_shared_experts
shared_experts_intermediate_size = (
self.num_shared_experts *
fd_config.model_config.deepseekv3.moe_intermediate_size)
self.shared_experts = DeepSeekV3MLP(
fd_config=fd_config,
intermediate_size=shared_experts_intermediate_size,
prefix=f"{prefix}.shared_experts",
reduce_results=False,
)
def load_state_dict(self, state_dict):
"""
"""
self.fused_moe.load_state_dict(state_dict)
self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor):
"""
"""
shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.fused_moe(hidden_states)
moe_out = moe_out + shared_experts_out
# Do the tensor-parallel all-reduce after summing the routed and shared expert outputs.
if self.tp_size > 1:
tensor_model_parallel_all_reduce(moe_out)
return moe_out
class DeepseekV3MLAAttention(nn.Layer):
"""
DeepseekV3MLAAttention
"""
def __init__(self,
fd_config: FDConfig,
layer_id: int,
prefix: str = "") -> None:
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
self.hidden_size = fd_config.model_config.hidden_size
self.num_attention_heads = fd_config.model_config.num_attention_heads
self.num_attention_heads_tp = self.num_attention_heads // self.tp_size
# MLA
self.qk_nope_head_dim = fd_config.model_config.deepseekv3.qk_nope_head_dim
self.qk_rope_head_dim = fd_config.model_config.deepseekv3.qk_rope_head_dim
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
self.v_head_dim = fd_config.model_config.deepseekv3.v_head_dim
self.q_lora_rank = fd_config.model_config.deepseekv3.q_lora_rank
self.kv_lora_rank = fd_config.model_config.deepseekv3.kv_lora_rank
self.attn_softmax_scale = self.qk_head_dim**-0.5
self.rope_theta = fd_config.model_config.rope_theta
self.rms_norm_eps = fd_config.model_config.rms_norm_eps
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(fd_config=fd_config,
prefix=f"{prefix}.q_a_proj",
input_size=self.hidden_size,
output_size=self.q_lora_rank,
with_bias=False)
self.q_a_layernorm = RMSNorm(fd_config,
hidden_size=self.q_lora_rank,
eps=self.rms_norm_eps,
prefix=f"{prefix}.q_a_layernorm")
self.q_b_proj = ColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.q_b_proj",
input_size=self.q_lora_rank,
output_size=self.num_attention_heads * self.qk_head_dim,
with_bias=False,
)
else:
assert (self.q_lora_rank is not None
), "q_lora_rank is None; please check your model config."
# No tensor-parallel split here; this projection runs as a W4A16 GEMM.
self.kv_a_proj_with_mqa = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
input_size=self.hidden_size,
output_size=self.kv_lora_rank + self.qk_rope_head_dim,
with_bias=False)
self.kv_a_layernorm = RMSNorm(fd_config,
hidden_size=self.kv_lora_rank,
eps=self.rms_norm_eps,
prefix=f"{prefix}.kv_a_layernorm")
self.kv_b_proj = ColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.kv_b_proj",
input_size=self.kv_lora_rank,
output_size=self.num_attention_heads *
(self.qk_nope_head_dim + self.v_head_dim),
with_bias=False,
)
self.o_proj = RowParallelLinear(fd_config,
prefix=f"{prefix}.o_proj",
input_size=self.num_attention_heads *
self.v_head_dim,
output_size=self.hidden_size,
with_bias=False)
self.kv_b_proj_bmm = KVBatchLinear(
fd_config=fd_config,
prefix=f"{prefix}.kv_b_proj",
kv_lora_rank=self.kv_lora_rank,
num_attention_heads=self.num_attention_heads,
qk_nope_head_dim=self.qk_nope_head_dim,
v_head_dim=self.v_head_dim)
self.rope_scaling = fd_config.model_config.deepseekv3.rope_scaling
if self.rope_scaling:
mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
scaling_factor = self.rope_scaling["factor"]
mscale = self.yarn_get_mscale(scaling_factor,
float(mscale_all_dim))
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
rope_scaling_kwargs = {
key: self.rope_scaling[key]
for key in [
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
] if key in self.rope_scaling
}
self.rope_scaling_factor = self.rope_scaling["factor"]
self.rope_scaling_original_max_position_embeddings = self.rope_scaling[
"original_max_position_embeddings"]
self.rotary_emb = DeepseekScalingRotaryEmbedding(
self.qk_rope_head_dim,
max_position_embeddings=self.
rope_scaling_original_max_position_embeddings,
base=self.rope_theta,
scaling_factor=self.rope_scaling_factor,
**rope_scaling_kwargs,
)
self.mla_attn = Attention(
fd_config=fd_config,
layer_id=layer_id,
prefix=prefix,
use_neox_rotary_style=False,
)
self.prefix = prefix
@staticmethod
def yarn_get_mscale(scale=1, mscale=1):
"""
"""
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
"""
"""
layernorm_out = hidden_states
fmha_out = paddle.zeros(shape=[
layernorm_out.shape[0],
self.num_attention_heads_tp * self.v_head_dim
],
dtype=layernorm_out.dtype)
decode_stage = forward_meta.is_decode_batch
prefill_stage = not (forward_meta.is_decode_batch)
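# The two branches below reflect the usual MLA trade-off: prefill materializes
# full per-head K/V from the compressed latent via kv_b_proj and runs standard
# attention, while decode absorbs kv_b_proj into the query (kv_b_proj_bmm) and
# attends directly in the compressed (kv_lora_rank + rope) space.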
if prefill_stage:
query = self.q_a_proj(layernorm_out)
query = self.q_a_layernorm(query)
query = self.q_b_proj(query)
query = query.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
compressed_kv, key_pe = compressed_kv.split(
[self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
key_value = self.kv_b_proj(compressed_kv)
key_value = key_value.reshape([
-1, self.num_attention_heads_tp,
self.qk_nope_head_dim + self.v_head_dim
])
key_nope, value = key_value.split(
[self.qk_nope_head_dim, self.v_head_dim], axis=-1)
query[..., self.qk_nope_head_dim:] = query_pe
key = paddle.empty_like(query)
key[..., :self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim:] = key_pe
value = paddle.nn.functional.pad(
value, [0, self.qk_head_dim - self.v_head_dim], value=0)
fmha_out_prefill = self.mla_attn(q=query,
k=key,
v=value,
qkv=None,
compressed_kv=compressed_kv,
k_pe=key_pe,
forward_meta=forward_meta)
fmha_out_prefill = fmha_out_prefill.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, :self.v_head_dim]
fmha_out_prefill = fmha_out_prefill.reshape(
[-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(
fmha_out_prefill.dtype)
fmha_out = fmha_out + fmha_out_prefill
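# Decode path: project q_nope into the KV latent space (weight absorption),
# attend against the cached compressed KV, then map the latent attention
# output back to per-head value space with the same kv_b_proj weights.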
if decode_stage:
query = self.q_a_proj(layernorm_out)
query = self.q_a_layernorm(query)
ln_out_or_q_c = query
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
compressed_kv, key_pe = compressed_kv.split(
[self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
query = self.q_b_proj(ln_out_or_q_c)
query = query.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]),
proj_type='k').transpose([1, 0, 2])
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
q_input = q_input.reshape([
-1,
self.num_attention_heads_tp *
(self.kv_lora_rank + self.qk_rope_head_dim),
])
fmha_out_decode = self.mla_attn(q=q_input,
k=None,
v=None,
qkv=None,
compressed_kv=compressed_kv,
k_pe=key_pe,
forward_meta=forward_meta)
fmha_out_decode = fmha_out_decode.reshape(
[-1, self.num_attention_heads_tp,
self.kv_lora_rank]).transpose([1, 0, 2])
fmha_out_decode = (self.kv_b_proj_bmm(
fmha_out_decode, proj_type='v').transpose([1, 0, 2]).reshape(
[-1, self.num_attention_heads_tp * self.v_head_dim]))
fmha_out = fmha_out + fmha_out_decode
output = self.o_proj(fmha_out)
return output
def load_state_dict(self, state_dict):
"""
"""
self.q_a_proj.load_state_dict(state_dict)
self.q_a_layernorm.load_state_dict(state_dict)
self.kv_a_proj_with_mqa.load_state_dict(state_dict)
self.kv_a_layernorm.load_state_dict(state_dict)
self.q_b_proj.load_state_dict(state_dict)
self.kv_b_proj_bmm.load_state_dict(state_dict)
self.kv_b_proj.load_state_dict(state_dict)
# NOTE(Ryan): make sure kv_b_proj_bmm is loaded before kv_b_proj;
# the shared weight key is popped once kv_b_proj has loaded.
self.o_proj.load_state_dict(state_dict)
class DeepSeekV3DecoderLayer(nn.Layer):
"""
DeepSeekV3DecoderLayer
"""
def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
) -> None:
super().__init__()
layer_id = int(prefix.split(sep='.')[-1])
self.self_attn = DeepseekV3MLAAttention(
fd_config=fd_config,
layer_id=layer_id,
prefix=f"{prefix}.self_attn",
)
if (fd_config.model_config.deepseekv3.n_routed_experts is not None
and layer_id
>= fd_config.model_config.deepseekv3.first_k_dense_replace):
self.mlp = DeepSeekV3MoE(
fd_config=fd_config,
layer_id=layer_id,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = DeepSeekV3MLP(
fd_config=fd_config,
intermediate_size=fd_config.model_config.intermediate_size,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
)
self.post_attention_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
)
def load_state_dict(self, state_dict):
"""
"""
self.self_attn.load_state_dict(state_dict)
self.mlp.load_state_dict(state_dict)
self.input_layernorm.load_state_dict(state_dict)
self.post_attention_layernorm.load_state_dict(state_dict)
def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
residual: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
"""
"""
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(forward_meta, hidden_states,
position_ids, mask_encoder_batch)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class DeepSeekV3Model(nn.Layer):
"""
DeepSeekV3Model
"""
def __init__(
self,
fd_config: FDConfig = None,
):
"""
Initializer for the DeepSeekV3Model class.
"""
super().__init__()
self.num_layers = fd_config.model_config.num_layers
fd_config.model_config.prefix_name = "deepseek_v3"
self.embeddings = VocabParallelEmbedding(
fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
params_dtype=paddle.get_default_dtype(),
prefix="deepseek_v3.embed_tokens",
)
self.decoder_layers = nn.LayerList([
DeepSeekV3DecoderLayer(
fd_config,
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
for i in range(self.num_layers)
])
self.norm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix="deepseek_v3.norm",
)
def pre_process(self, forward_meta):
"""
"""
seq_lens_encoder = forward_meta.seq_lens_encoder
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
position_ids_shape = paddle.sum(seq_lens_this_time)
position_ids = paddle.empty(shape=position_ids_shape,
dtype=seq_lens_encoder.dtype)
mask_encoder_batch = paddle.empty(
shape=position_ids_shape,
dtype=seq_lens_encoder.dtype).unsqueeze(1)
get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
position_ids,
mask_encoder_batch)
return position_ids, mask_encoder_batch
def load_state_dict(self, state_dict):
"""
Load model parameters from a given state dictionary.
"""
self.embeddings.load_state_dict(state_dict)
self.norm.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
self.decoder_layers[i].load_state_dict(state_dict)
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
"""
"""
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
position_ids, mask_encoder_batch = self.pre_process(forward_meta)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.decoder_layers[i](
forward_meta, hidden_states, residual, position_ids,
mask_encoder_batch)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
return out
class DeepseekV3ForCausalLM(ModelForCasualLM):
"""
DeepseekV3ForCausalLM
"""
def __init__(self, fd_config: FDConfig):
"""
Args:
fd_config (FDConfig): Configurations for the LLM model.
"""
super().__init__(fd_config)
self.model = DeepSeekV3Model(fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
self.lm_head = ParallelLMHead(
fd_config,
embedding_dim=fd_config.model_config.hidden_size,
num_embeddings=fd_config.model_config.vocab_size,
prefix="lm_head",
)
@classmethod
def name(cls):
"""
"""
return "DeepseekV3ForCausalLM"
@paddle.no_grad()
def set_state_dict(self, state_dict):
"""
Load model parameters from a given state dictionary.
"""
self.model.load_state_dict(state_dict)
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
"""
"""
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits[:, self.ori_vocab_size:] = -float("inf")
return logits
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
"""
"""
hidden_states = self.model(ids_remove_padding, forward_meta)
return hidden_states
class DeepSeekV3PretrainedModel(PretrainedModel):
"""
DeepSeekV3PretrainedModel
"""
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
@classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True):
logger.info("DeepseekV3 inference model _get_tensor_parallel_mappings")
from paddleformers.transformers.conversion_utils import \
split_or_merge_func
fn = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
)
def get_tensor_parallel_split_mappings(num_layers):
final_actions = {}
base_actions = {
"lm_head.weight": partial(fn, is_column=True),
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn,
is_column=False),
}
# Self-attention projections that need TP splitting.
base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.self_attn.q_b_proj.weight_scale_inv"] = partial(
fn, is_column=True)
base_actions[
"layers.0.self_attn.kv_b_proj.weight_scale_inv"] = partial(
fn, is_column=True)
# MLP Layer
base_actions["layers.0.mlp.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.mlp.up_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.mlp.down_proj.weight"] = partial(
fn, is_column=False)
# Moe Layer
for expert_idx in range(config.n_routed_experts):
base_actions[
f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(
fn, is_column=True)
base_actions[
f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions[
f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(
fn, is_column=False)
# Shared Expert Layer
base_actions[
"layers.0.mlp.shared_experts.up_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.mlp.shared_experts.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.mlp.shared_experts.down_proj.weight"] = partial(
fn, is_column=False)
# MTP parts
base_actions["layers.61.embed_tokens.weight"] = partial(
fn, is_column=False)
base_actions["layers.61.eh_proj.weight"] = partial(fn,
is_column=True)
base_actions["layers.61.shared_head.head.weight"] = partial(
fn, is_column=True)
for key, action in base_actions.items():
if "layers.0." in key:
for i in range(num_layers):
final_actions[key.replace("layers.0.",
f"layers.{i}.")] = action
final_actions[key] = action
return final_actions
mappings = get_tensor_parallel_split_mappings(config.num_layers)
return mappings
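Each entry in the returned mapping ties a checkpoint weight name to a split/merge partial; a hedged illustration of applying one entry (config and full_weight are placeholders):

mappings = DeepSeekV3PretrainedModel._get_tensor_parallel_mappings(config, is_split=True)
split_fn = mappings["layers.0.self_attn.q_b_proj.weight"]  # partial(fn, is_column=True)
local_shard = split_fn(full_weight)                        # shard for this TP rank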

View File

@@ -29,7 +29,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -37,291 +37,13 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import \
LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta
from fastdeploy.worker.forward_meta import ForwardMeta
class Ernie4_5_PretrainedModel(PretrainedModel):
"""
Ernie4_5_PretrainedModel
"""
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
@classmethod
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
"""
get_tensor_parallel_mappings
"""
logger.info("erine inference model _get_tensor_parallel_mappings")
from paddleformers.transformers.conversion_utils import \
split_or_merge_func
fn = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
)
def gqa_qkv_split_func(
weight,
tensor_parallel_degree,
tensor_parallel_rank,
num_attention_heads,
num_key_value_heads,
head_dim,
):
def get_shape(tensor):
return (tensor.get_shape()
if hasattr(tensor, "get_shape") else tensor.shape)
def slice_tensor(tensor, start, end):
shape = get_shape(tensor)
if len(shape) == 1:
return tensor[start:end]
else:
return tensor[..., start:end]
q_end = num_attention_heads * head_dim
k_end = q_end + num_key_value_heads * head_dim
v_end = k_end + num_key_value_heads * head_dim
q = slice_tensor(weight, 0, q_end)
k = slice_tensor(weight, q_end, k_end)
v = slice_tensor(weight, k_end, v_end)
def split_tensor(tensor, degree):
shape = get_shape(tensor)
size = shape[-1]
block_size = size // degree
if hasattr(tensor, "get_shape"):
return [
slice_tensor(tensor, i * block_size,
(i + 1) * block_size)
for i in range(degree)
]
else:
return np.split(tensor, degree, axis=-1)
q_list = split_tensor(q, tensor_parallel_degree)
k_list = split_tensor(k, tensor_parallel_degree)
v_list = split_tensor(v, tensor_parallel_degree)
if tensor_parallel_rank is None:
return [
np.concatenate([q_i, k_i, v_i], axis=-1)
for q_i, k_i, v_i in zip(q_list, k_list, v_list)
]
else:
return np.concatenate(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=-1,
)
def gqa_qkv_merge_func(weight_list, num_attention_heads,
num_key_value_heads, head_dim):
tensor_parallel_degree = len(weight_list)
num_attention_heads = num_attention_heads // tensor_parallel_degree
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
def get_shape(tensor):
return (tensor.get_shape()
if hasattr(tensor, "get_shape") else tensor.shape)
def slice_tensor(tensor, start, end):
if len(get_shape(tensor)) == 1:
return tensor[start:end]
else:
return tensor[..., start:end]
q_list, k_list, v_list = [], [], []
for weight in weight_list:
q_end = num_attention_heads * head_dim
k_end = q_end + num_key_value_heads * head_dim
v_end = k_end + num_key_value_heads * head_dim
q = slice_tensor(weight, 0, q_end)
k = slice_tensor(weight, q_end, k_end)
v = slice_tensor(weight, k_end, v_end)
q_list.append(q)
k_list.append(k)
v_list.append(v)
merged = q_list + k_list + v_list
if is_paddle_tensor:
tensor = paddle.concat(merged, axis=-1)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor
else:
return np.concatenate(merged, axis=-1)
if (config.num_key_value_heads is not None
and config.num_key_value_heads != config.num_attention_heads):
if is_split:
qkv_fn = partial(
gqa_qkv_split_func,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
head_dim=config.head_dim,
)
else:
qkv_fn = partial(
gqa_qkv_merge_func,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
head_dim=config.head_dim,
)
else:
qkv_fn = partial(fn, is_column=True)
def get_tensor_parallel_split_mappings(num_layers, moe_num_experts,
moe_num_shared_experts,
moe_layer_start_index):
final_actions = {}
base_model_prefix = "ernie"
base_actions = {
"lm_head.weight":
partial(fn, is_column=True),
# "eh_proj.weight": partial(fn, is_column=True),
f"{base_model_prefix}.embed_tokens.weight":
partial(fn, is_column=False),
}
base_actions[
f"{base_model_prefix}.layers.0.self_attn.qkv_proj.weight"] = qkv_fn
base_actions[
f"{base_model_prefix}.layers.0.self_attn.qkv_proj.quant_weight"] = qkv_fn
base_actions[
f"{base_model_prefix}.layers.0.self_attn.o_proj.weight"] = partial(
fn, is_column=False)
base_actions[
f"{base_model_prefix}.layers.0.self_attn.o_proj.quant_weight"] = partial(
fn, is_column=False)
base_actions[
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.0.mlp.down_proj.weight"] = (
partial(fn, is_column=False))
base_actions[
f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"] = partial(
fn, is_column=False)
for expert_idx in range(moe_num_experts):
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.experts.{expert_idx}.up_gate_proj.weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.experts.{expert_idx}.up_gate_proj.quant_weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.experts.{expert_idx}.down_proj.weight"] = partial(
fn, is_column=False)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.experts.{expert_idx}.down_proj.quant_weight"] = partial(
fn, is_column=False)
if moe_num_shared_experts > 0:
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.shared_experts.up_gate_proj.weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial(
fn, is_column=True, is_naive_2fuse=True)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.shared_experts.down_proj.weight"] = partial(
fn, is_column=False)
base_actions[
f"{base_model_prefix}.layers.{moe_layer_start_index}"
f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial(
fn, is_column=False, is_naive_2fuse=True)
for key, action in base_actions.items():
if (f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"
in key or
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"
in key
or f"{base_model_prefix}.layers.0.mlp.down_proj.weight"
in key or
f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"
in key):
for i in range(moe_layer_start_index):
final_actions[key.replace("layers.0.",
f"layers.{i}.")] = action
elif f"layers.{moe_layer_start_index}.mlp.experts." in key:
for i in range(moe_layer_start_index, num_layers):
final_actions[key.replace(
f"layers.{moe_layer_start_index}.",
f"layers.{i}.")] = action
elif f"layers.{moe_layer_start_index}.mlp.shared_experts." in key:
for i in range(moe_layer_start_index, num_layers):
final_actions[key.replace(
f"layers.{moe_layer_start_index}.",
f"layers.{i}.")] = action
elif f"{base_model_prefix}.layers.0." in key:
for i in range(num_layers):
final_actions[key.replace("layers.0.",
f"layers.{i}.")] = action
final_actions[key] = action
return final_actions
moe_num_experts = 0
moe_num_shared_experts = 0
if isinstance(config.moe_num_experts, list):
moe_num_experts = sum(config.moe_num_experts)
elif isinstance(config.moe_num_experts, int):
moe_num_experts = config.moe_num_experts
if hasattr(config, 'moe_num_shared_experts'):
moe_num_shared_experts = config.moe_num_shared_experts
moe_layer_start_index = -1
if isinstance(config.moe_layer_start_index, list):
moe_layer_start_index = min(config.moe_layer_start_index)
elif isinstance(config.moe_layer_start_index, int):
moe_layer_start_index = config.moe_layer_start_index
mappings = get_tensor_parallel_split_mappings(
config.num_layers,
moe_num_experts,
moe_num_shared_experts,
moe_layer_start_index,
)
return mappings
class Ernie4_5_MLP(nn.Layer):
def __init__(
@@ -329,6 +51,7 @@ class Ernie4_5_MLP(nn.Layer):
fd_config: FDConfig,
intermediate_size: int,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.nranks = fd_config.parallel_config.tensor_parallel_degree
@@ -339,13 +62,12 @@ class Ernie4_5_MLP(nn.Layer):
output_size=intermediate_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)
self.down_proj = RowParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.down_proj",
input_size=(intermediate_size // self.nranks),
input_size=intermediate_size,
output_size=fd_config.model_config.hidden_size,
with_bias=False,
)
@@ -423,8 +145,8 @@ class Ernie4_5_MoE(nn.Layer):
f"{prefix}.experts.{{}}.down_proj.code_zp",
}
elif moe_quant_type == "tensor_wise_fp8" or (
moe_quant_type == "block_wise_fp8" and
fd_config.model_config.is_quantized):
moe_quant_type == "block_wise_fp8"
and fd_config.model_config.is_quantized):
weight_key_map = {
"gate_weight_key":
f"{prefix}.gate.weight",
@@ -492,8 +214,6 @@ class Ernie4_5_Attention(nn.Layer):
prefix: str) -> None:
super().__init__()
nranks = fd_config.parallel_config.tensor_parallel_degree
self.qkv_proj = QKVParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.qkv_proj",
@@ -502,8 +222,8 @@ class Ernie4_5_Attention(nn.Layer):
self.o_proj = RowParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.o_proj",
input_size=(fd_config.model_config.head_dim *
fd_config.model_config.num_attention_heads // nranks),
input_size=fd_config.model_config.head_dim *
fd_config.model_config.num_attention_heads,
output_size=fd_config.model_config.hidden_size,
)
self.attn = Attention(
@@ -636,12 +356,12 @@ class Ernie4_5_Model(nn.Layer):
params_dtype=paddle.get_default_dtype(),
prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"))
self.hidden_layers = [
self.hidden_layers = nn.LayerList([
Ernie4_5_DecoderLayer(
fd_config=fd_config,
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
for i in range(self.num_layers)
]
])
self.norm = RMSNorm(
fd_config,
@@ -772,3 +492,134 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
Model Architecture Name
"""
return "Ernie4_5_ForCausalLM"
class Ernie4_5_PretrainedModel(PretrainedModel):
"""
Ernie4_5_PretrainedModel
"""
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
weight_infos = [
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
True, tsm.GQA),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight",
False),
WeightMeta(
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight",
True, tsm.PairFused),
WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight",
False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.weight",
False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
False),
WeightMeta(".embed_tokens.weight", False),
WeightMeta("lm_head.weight", True),
# quant tensorwise
WeightMeta(
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.quant_weight",
True, tsm.GQA),
WeightMeta(
f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.quant_weight",
False),
WeightMeta(
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.quant_weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.quant_weight",
False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.quant_weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.quant_weight",
False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.quant_weight",
True, tsm.PairFused),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.quant_weight",
False),
]
@classmethod
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
"""
get_tensor_parallel_mappings
"""
logger.info("erine inference model _get_tensor_parallel_mappings")
from fastdeploy.model_executor.models.tp_utils import (
build_expanded_keys, has_prefix, split_or_merge_func_v1)
fn = split_or_merge_func_v1(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
head_dim=config.head_dim)
def get_tensor_parallel_split_mappings(num_layers, moe_num_experts,
moe_layer_start_index,
prefix_name):
base_actions = {}
weight_infos = cls.weight_infos
for (weight_name, is_column, extra) in weight_infos:
params = {
"is_column": is_column,
**({
extra.value: True
} if extra else {})
}
if "lm_head.weight" in weight_name:
key = weight_name
elif not has_prefix(prefix_name, weight_name):
key = f"{prefix_name}{weight_name}"
else:
key = weight_name
base_actions[key] = partial(fn, **params)
final_actions = {}
start_layer = (moe_layer_start_index
if moe_layer_start_index > 0 else num_layers)
final_actions = build_expanded_keys(
num_layers,
moe_num_experts,
start_layer,
base_actions,
)
return final_actions
moe_num_experts = 0
if isinstance(config.moe_num_experts, list):
moe_num_experts = sum(config.moe_num_experts)
elif isinstance(config.moe_num_experts, int):
moe_num_experts = config.moe_num_experts
moe_layer_start_index = -1
if isinstance(config.moe_layer_start_index, list):
moe_layer_start_index = min(config.moe_layer_start_index)
elif isinstance(config.moe_layer_start_index, int):
moe_layer_start_index = config.moe_layer_start_index
mappings = get_tensor_parallel_split_mappings(config.num_layers,
moe_num_experts,
moe_layer_start_index,
config.prefix_name)
return mappings

View File

@@ -26,7 +26,7 @@ from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
@@ -265,12 +265,12 @@ class Ernie4_5_MTPModel(nn.Layer):
self.num_layers = fd_config.model_config.num_layers
self.embeddings = fd_config.speculative_config.sharing_model.model.embeddings
self.hidden_layers = [
self.hidden_layers = nn.LayerList([
Ernie4_5_DecoderLayer(
fd_config=fd_config,
prefix=f"{fd_config.model_config.prefix_name}.{i}")
for i in range(self.num_layers)
]
])
self.enorm = RMSNorm(
fd_config,
@@ -286,7 +286,7 @@ class Ernie4_5_MTPModel(nn.Layer):
prefix="ernie.mtp_hidden_norm.0",
)
self.eh_proj = ParallelLMHead(
self.eh_proj = ParallelEHProjection(
fd_config=fd_config,
num_embeddings=fd_config.model_config.hidden_size,
embedding_dim=fd_config.model_config.hidden_size * 2,

View File

@@ -25,6 +25,8 @@ from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
@@ -66,6 +68,7 @@ class Ernie4_5_VLMoE(nn.Layer):
prefix: str) -> None:
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
moe_layer_start_index = fd_config.moe_config.moe_layer_start_index
if isinstance(moe_layer_start_index, int):
text_moe_layer_start_index = moe_layer_start_index
@@ -99,6 +102,7 @@ class Ernie4_5_VLMoE(nn.Layer):
}
self.mlp_text = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.moe_config.
moe_intermediate_size[0],
num_experts=fd_config.moe_config.num_experts[0],
@@ -130,6 +134,7 @@ class Ernie4_5_VLMoE(nn.Layer):
}
self.mlp_image = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.moe_config.
moe_intermediate_size[1],
num_experts=fd_config.moe_config.num_experts[1],
@@ -154,6 +159,7 @@ class Ernie4_5_VLMoE(nn.Layer):
intermediate_size=self.num_shared_experts *
fd_config.moe_config.moe_intermediate_size[0],
prefix=f"{prefix}.shared_experts",
reduce_results=False,
)
def extract_gate_correction_bias_text(self, gate_correction_bias_key,
@@ -210,6 +216,8 @@ class Ernie4_5_VLMoE(nn.Layer):
hidden_states = self.mlp_text(hidden_states)
if self.num_shared_experts > 0:
hidden_states += share_experts_out
if self.tp_size > 1:
tensor_model_parallel_all_reduce(hidden_states)
return hidden_states
@@ -337,12 +345,12 @@ class Ernie4_5_VLModel(nn.Layer):
prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"),
)
self.hidden_layers = [
self.hidden_layers = nn.LayerList([
Ernie4_5_VLDecoderLayer(
fd_config=fd_config,
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
for i in range(self.num_layers)
]
])
self.norm = RMSNorm(
fd_config,

View File

@@ -29,6 +29,7 @@ class ModelRegistry:
@classmethod
def register(cls, model_class):
"""register model class"""
if issubclass(
model_class,
ModelForCasualLM) and model_class is not ModelForCasualLM:
@@ -37,6 +38,7 @@ class ModelRegistry:
@classmethod
def get_class(cls, name):
"""get model class"""
if name not in cls._registry:
raise ValueError(f"Model '{name}' is not registered!")
return cls._registry[name]

View File

@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -55,13 +55,12 @@ class Qwen2MLP(nn.Layer):
output_size=fd_config.model_config.ffn_hidden_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)
self.down_proj = RowParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.down_proj",
input_size=(fd_config.model_config.ffn_hidden_size // self.nranks),
input_size=fd_config.model_config.ffn_hidden_size,
output_size=fd_config.model_config.hidden_size,
with_bias=False,
)
@@ -97,8 +96,6 @@ class Qwen2Attention(nn.Layer):
prefix: str = "") -> None:
super().__init__()
nranks = fd_config.parallel_config.tensor_parallel_degree
self.qkv_proj = QKVParallelLinear(fd_config=fd_config,
prefix=f"{prefix}.qkv_proj",
with_bias=True)
@@ -106,7 +103,7 @@ class Qwen2Attention(nn.Layer):
self.o_proj = RowParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.o_proj",
input_size=(fd_config.model_config.hidden_size // nranks),
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.hidden_size,
)
@@ -305,6 +302,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
"""
super(Qwen2ForCausalLM, self).__init__(fd_config)
self.fd_config = fd_config
self.model = Qwen2Model(fd_config=fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size

View File

@@ -26,7 +26,7 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
@@ -68,7 +68,7 @@ class Qwen3Attention(nn.Layer):
fd_config=fd_config,
prefix=f"{prefix}.o_proj",
input_size=fd_config.model_config.head_dim *
fd_config.model_config.num_attention_heads // nranks,
fd_config.model_config.num_attention_heads,
output_size=fd_config.model_config.hidden_size,
)

View File

@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
@@ -57,13 +57,12 @@ class Qwen3MLP(nn.Layer):
output_size=fd_config.model_config.ffn_hidden_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
use_fast_ffn=True,
)
self.down_proj = RowParallelLinear(
fd_config,
prefix=f"{prefix}.down_proj",
input_size=(fd_config.model_config.ffn_hidden_size // self.nranks),
input_size=fd_config.model_config.ffn_hidden_size,
output_size=fd_config.model_config.hidden_size,
with_bias=False,
)
@@ -111,7 +110,7 @@ class Qwen3Attention(nn.Layer):
fd_config,
prefix=f"{prefix}.o_proj",
input_size=fd_config.model_config.head_dim *
fd_config.model_config.num_attention_heads // nranks,
fd_config.model_config.num_attention_heads,
output_size=fd_config.model_config.hidden_size,
)

View File

@@ -0,0 +1,405 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import re
from enum import Enum
from functools import partial
from typing import Dict, List
import numpy as np
import paddle
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.conversion_utils import split_or_merge_func
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder
def check_tensor_parallel_prerequisites(
fd_config: FDConfig,
cls: PretrainedModel,
tensor_parallel_filtered_map: Dict[str, partial],
safetensor_keys: List[str],
) -> None:
"""check_tensor_parallel_prerequisites"""
if fd_config.parallel_config.tensor_parallel_degree > 1:
tensor_parallel_map = cls._get_tensor_parallel_mappings(
fd_config.model_config, is_split=True)
if not tensor_parallel_map:
logger.error("filtered_quant_map should not be empty. \
parallel splitting required, but _get_tensor_parallel_mappings is not implemented."
)
filtered_tp_keys = cls._resolve_prefix_keys(tensor_parallel_map.keys(),
safetensor_keys)
for k, v in filtered_tp_keys.items():
tensor_parallel_filtered_map[v] = tensor_parallel_map.pop(k)
if not tensor_parallel_filtered_map:
logger.error("tensor_parallel_filtered_map should not be empty. \
The weights required for tensor parallel splitting are inconsistent with the model's weights."
)
def extract_prefix(weight_name: str) -> str:
"""extract_prefix"""
if weight_name.startswith("."):
return ""
parts = weight_name.split(".", 1)
return parts[0] if len(parts) > 1 else ""
def has_prefix(prefix_name: str, weight_name: str):
"""has_prefix"""
return prefix_name == extract_prefix(weight_name)
class TensorSplitMode(Enum):
"""TensorSplitMode"""
GQA = "is_gqa"
TRANSPOSE = "transpose"
QKV = "is_old_qkv"
PairFused = "is_naive_2fuse"
TripletFused = "is_naive_3fuse"
def extract_placeholders(template: str):
"""extract_placeholders"""
return set(re.findall(r"{(\w+)}", template))
class SafeDict(dict):
"""SafeDict"""
def __missing__(self, key):
return "{" + key + "}"
def has_placeholders(placeholders):
"""has_placeholders"""
return len(placeholders) > 0
def update_final_actions(params, final_actions, key, action):
"""update_final_actions"""
new_key = key.format_map(SafeDict(params))
final_actions[new_key] = action
def build_expanded_keys(num_layers, num_experts, start_layer, base_actions):
"""build_expanded_keys"""
final_actions = {}
for key, action in base_actions.items():
placeholders = extract_placeholders(key)
if not has_placeholders(placeholders):
final_actions[key] = action
else:
if LayerIdPlaceholder.LAYER_ID.value in placeholders:
for layer_id in range(num_layers):
update_final_actions(
{LayerIdPlaceholder.LAYER_ID.value: layer_id},
final_actions,
key,
action,
)
elif LayerIdPlaceholder.FFN_LAYER_ID.value in placeholders:
for layer_id in range(start_layer):
update_final_actions(
{LayerIdPlaceholder.FFN_LAYER_ID.value: layer_id},
final_actions,
key,
action,
)
elif (LayerIdPlaceholder.MOE_LAYER_ID.value in placeholders
and LayerIdPlaceholder.EXPERT_ID.value in placeholders):
for layer_id in range(start_layer, num_layers):
for expert_id in range(num_experts):
update_final_actions(
{
LayerIdPlaceholder.MOE_LAYER_ID.value:
layer_id,
LayerIdPlaceholder.EXPERT_ID.value: expert_id,
},
final_actions,
key,
action,
)
elif (LayerIdPlaceholder.MOE_LAYER_ID.value in placeholders
and len(placeholders) == 1):
for layer_id in range(start_layer, num_layers):
update_final_actions(
{LayerIdPlaceholder.MOE_LAYER_ID.value: layer_id},
final_actions,
key,
action,
)
else:
logger.error(f"{key} does not match any case.")
return final_actions
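A small hedged example of the expansion performed above (the template and split_fn are illustrative):

base = {"ernie.layers.{layer_id}.self_attn.qkv_proj.weight": split_fn}
expanded = build_expanded_keys(num_layers=2, num_experts=0, start_layer=2, base_actions=base)
# expanded == {
#     "ernie.layers.0.self_attn.qkv_proj.weight": split_fn,
#     "ernie.layers.1.self_attn.qkv_proj.weight": split_fn,
# }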
def gqa_qkv_split_func(
tensor_parallel_degree,
tensor_parallel_rank,
num_attention_heads,
num_key_value_heads,
head_dim,
):
"""
gqa_qkv_split_func
"""
def fn(x, is_column=True):
"""fucn"""
def get_shape(tensor):
"""get_shape"""
return tensor.get_shape() if hasattr(tensor,
"get_shape") else tensor.shape
def slice_tensor(tensor, start, end):
"""slice_tensor"""
shape = get_shape(tensor)
if len(shape) == 1:
return tensor[start:end]
elif is_column:
return tensor[..., start:end]
else:
return tensor[start:end, ...]
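# Fused GQA projection layout along the split axis:
# [ q: num_attention_heads*head_dim | k: num_key_value_heads*head_dim | v: num_key_value_heads*head_dim ]
# Each section is sliced out and split independently across TP ranks below.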
q_end = num_attention_heads * head_dim
k_end = q_end + num_key_value_heads * head_dim
v_end = k_end + num_key_value_heads * head_dim
q = slice_tensor(x, 0, q_end)
k = slice_tensor(x, q_end, k_end)
v = slice_tensor(x, k_end, v_end)
def split_tensor(tensor, degree):
"""
split_tensor
"""
shape = get_shape(tensor)
size = shape[-1] if is_column else shape[0]
block_size = size // degree
if hasattr(tensor, "get_shape"):
return [
slice_tensor(tensor, i * block_size, (i + 1) * block_size)
for i in range(degree)
]
else:
if isinstance(x, paddle.Tensor):
if is_column:
return paddle.split(tensor, degree, axis=-1)
else:
return paddle.split(tensor, degree, axis=0)
else:
if is_column:
return np.split(tensor, degree, axis=-1)
else:
return np.split(tensor, degree, axis=0)
q_list = split_tensor(q, tensor_parallel_degree)
k_list = split_tensor(k, tensor_parallel_degree)
v_list = split_tensor(v, tensor_parallel_degree)
if tensor_parallel_rank is None:
res = []
for q_i, k_i, v_i in zip(q_list, k_list, v_list):
if is_column:
if isinstance(x, paddle.Tensor):
res.append(paddle.concat([q_i, k_i, v_i], axis=-1))
else:
res.append(np.concatenate([q_i, k_i, v_i], axis=-1))
else:
if isinstance(x, paddle.Tensor):
res.append(paddle.concat([q_i, k_i, v_i], axis=0))
else:
res.append(np.concatenate([q_i, k_i, v_i], axis=0))
return res
else:
if isinstance(x, paddle.Tensor):
if is_column:
return paddle.concat(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=-1,
)
else:
return paddle.concat(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=0,
)
else:
if is_column:
return np.concatenate(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=-1,
)
else:
return np.concatenate(
[
q_list[tensor_parallel_rank],
k_list[tensor_parallel_rank],
v_list[tensor_parallel_rank],
],
axis=0,
)
return fn
def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
"""
gqa_qkv_merge_func
"""
def fn(weight_list, is_column=True):
"""fn"""
tensor_parallel_degree = len(weight_list)
# Use new local names for the per-rank head counts; reassigning the closure
# variables here would raise UnboundLocalError.
num_heads_per_rank = num_attention_heads // tensor_parallel_degree
num_kv_heads_per_rank = num_key_value_heads // tensor_parallel_degree
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
def get_shape(tensor):
"""
get_shape
"""
return tensor.get_shape() if hasattr(tensor,
"get_shape") else tensor.shape
def slice_tensor(tensor, start, end):
"""
slice_tensor
"""
if len(get_shape(tensor)) == 1:
return tensor[start:end]
elif is_column:
return tensor[..., start:end]
else:
return tensor[start:end, ...]
q_list, k_list, v_list = [], [], []
for weight in weight_list:
q_end = num_heads_per_rank * head_dim
k_end = q_end + num_kv_heads_per_rank * head_dim
v_end = k_end + num_kv_heads_per_rank * head_dim
q = slice_tensor(weight, 0, q_end)
k = slice_tensor(weight, q_end, k_end)
v = slice_tensor(weight, k_end, v_end)
q_list.append(q)
k_list.append(k)
v_list.append(v)
merged = q_list + k_list + v_list
if is_paddle_tensor:
if is_column:
tensor = paddle.concat(merged, axis=-1)
else:
tensor = paddle.concat(merged, axis=0)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor
else:
if is_column:
return np.concatenate(merged, axis=-1)
else:
return np.concatenate(merged, axis=0)
return fn
def split_or_merge_qkv_func(
is_split,
tensor_parallel_degree,
tensor_parallel_rank,
num_attention_heads,
num_key_value_heads,
head_dim,
):
"""
split_or_merge_qkv_func
"""
if is_split:
return gqa_qkv_split_func(
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
)
else:
return gqa_qkv_merge_func(
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
)
def split_or_merge_func_v1(
is_split,
tensor_parallel_degree,
tensor_parallel_rank,
num_attention_heads=None,
num_key_value_heads=None,
head_dim=None,
):
"""
split_or_merge_func_v1
"""
def fn(x, **kwargs):
"""func"""
is_gqa = kwargs.pop("is_gqa", False)
if is_gqa:
func = split_or_merge_qkv_func(
is_split=is_split,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
)
is_column = kwargs.pop("is_column", True)
return func(x, is_column=is_column)
else:
func = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
num_attention_heads=num_attention_heads,
)
is_column = kwargs.pop("is_column", True)
is_naive_2fuse = kwargs.pop("is_naive_2fuse", False)
return func(x, is_column=is_column, is_naive_2fuse=is_naive_2fuse)
return fn
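A hedged sketch of using the dispatcher above (degrees, ranks and the weight tensors are placeholders):

fn = split_or_merge_func_v1(
    is_split=True,
    tensor_parallel_degree=2,
    tensor_parallel_rank=0,
    num_attention_heads=32,
    num_key_value_heads=4,
    head_dim=128,
)
qkv_shard = fn(qkv_weight, is_column=True, is_gqa=True)  # GQA-aware column split for this rank
down_shard = fn(down_weight, is_column=False)            # plain row split via split_or_merge_func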

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import enum
import hashlib
import json
import os
@@ -23,29 +24,47 @@ import random
import re
import struct
from functools import partial
from typing import NamedTuple, Optional
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.common_ops_import import convert_dtype
from paddle.distributed import fleet
from paddleformers.transformers.model_utils import (_add_variant,
load_tp_checkpoint)
from paddleformers.transformers.model_utils import _add_variant
from paddleformers.transformers.utils import paddleformers_load
from paddleformers.utils.env import (PADDLE_WEIGHTS_INDEX_NAME,
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_PEFT_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_INDEX_NAME)
from paddleformers.utils.log import logger
from safetensors import safe_open
from tqdm import tqdm
from fastdeploy.config import ModelConfig
MAX_BSZ = 512
MAX_DRAFT_TOKENS = 6
class LayerIdPlaceholder(str, enum.Enum):
"""LayerIdPlaceholder"""
LAYER_ID = "layer_id"
FFN_LAYER_ID = "ffn_layer_id"
MOE_LAYER_ID = "moe_layer_id"
EXPERT_ID = "export_id"
class WeightMeta(NamedTuple):
"""
Tensor-split metadata for one weight.
weight_name: name of the weight
is_column: whether the weight is split along its columns
extra: optional flag such as "is_naive_2fuse", "is_gqa", or "is_naive_3fuse"
"""
weight_name: str
is_column: bool
extra: Optional[str] = None
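For illustration, a WeightMeta table might be populated like the sketch below; the weight names and {layer_id} placeholder are invented, and the extra flag presumably maps onto the is_gqa / is_naive_2fuse keyword arguments consumed by split_or_merge_func_v1 above.

# Illustrative only: placeholder weight names, not taken from a real model definition.
example_weight_metas = [
    WeightMeta("model.layers.{layer_id}.self_attn.qkv_proj.weight", True, "is_gqa"),
    WeightMeta("model.layers.{layer_id}.mlp.up_gate_proj.weight", True, "is_naive_2fuse"),
    WeightMeta("model.layers.{layer_id}.self_attn.o_proj.weight", False),
]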
class UniqueIDGenerator:
"""
The generator for the export model id
@@ -433,223 +452,6 @@ def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
return total_effective_tokens, total_tokens
def load_ep_checkpoint(model_path: str,
config: ModelConfig,
return_numpy: bool = False,
return_key_name: bool = True):
"""
load ep checkpoint
"""
# return_numpy=True: keep weights as numpy arrays on CPU.
# return_numpy=False: copy weights to the current device as paddle tensors.
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
weight_list = json.load(f)["weight_map"]
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
num_local_ffn_keys = []
for i in range(config.moe_layer_start_index, config.num_layers):
for j in range(
config.num_experts_start_offset,
config.num_experts_start_offset + config.num_experts_per_rank,
):
ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
ffn2_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
ffn2_quant_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight")
ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
ffn2_scale_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale")
num_local_ffn_keys.append(ffn1_key)
num_local_ffn_keys.append(ffn2_key)
num_local_ffn_keys.append(ffn1_quant_key)
num_local_ffn_keys.append(ffn2_quant_key)
num_local_ffn_keys.append(ffn1_scale_key)
num_local_ffn_keys.append(ffn2_scale_key)
for k in num_local_ffn_keys:
if k in weight_list:
filtered_map[k] = weight_list[k]
state_dict = {}
# Get all safetensor file paths that need to be opened
safetensor_paths = set(filtered_map.values())
# Open each safetensor file sequentially with progress bar
for safetensor_path in tqdm(safetensor_paths,
desc="Loading safetensor files",
unit="file"):
with safe_open(os.path.join(model_path, safetensor_path),
framework="np",
device="cpu") as f:
# Check if this file contains keys from filtered_map
for k in filtered_map:
if filtered_map[k] == safetensor_path and k in f.keys():
weight = f.get_tensor(k)
if not return_numpy:
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(
paddle.framework._current_expected_place(), False)
state_dict[k] = weight
return state_dict
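To make the expert filtering concrete, here is a self-contained toy version with an invented weight_map (file names, layer range, and expert assignment are all made up):

# Toy weight_map in the shape of model.safetensors.index.json's "weight_map" field.
weight_map = {
    "ernie.layers.1.self_attn.qkv_proj.weight": "model-00001.safetensors",
    "ernie.layers.1.mlp.experts.0.up_gate_proj.weight": "model-00001.safetensors",
    "ernie.layers.1.mlp.experts.1.up_gate_proj.weight": "model-00002.safetensors",
}
moe_layer_start_index, num_layers = 1, 2
num_experts_start_offset, num_experts_per_rank = 1, 1  # this rank owns expert 1 only

# Keep every non-expert weight, plus only the experts assigned to this rank.
filtered = {k: v for k, v in weight_map.items() if "experts" not in k}
for i in range(moe_layer_start_index, num_layers):
    for j in range(num_experts_start_offset, num_experts_start_offset + num_experts_per_rank):
        key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
        if key in weight_map:
            filtered[key] = weight_map[key]

print(sorted(filtered))  # the qkv_proj weight plus experts.1 only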
def get_safetensor_file(model_path):
"""
get_safetensor_file
"""
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
os.path.join(model_path, weight_map[weight_name]))
key_name_list = list(set(weight_map.keys()))
safetensor_list = list(weight_files_in_index)
safetensor_list.sort()
return key_name_list, safetensor_list
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
"""
safetensors_weights_iterator
"""
for st_file in tqdm(
safe_tensor_list,
desc="Loading safetensors checkpoint shards",
):
with safe_open(st_file, framework="np") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
def fastsafetensors_weights_iterator(safetensor_list: list[str]):
"""
fastsafetensors_weights_iterator
"""
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
world_size = dist.get_world_size()
if world_size > 1:
dist.init_parallel_env()
pg = dist.get_group()
device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
else:
pg = SingleGroup()
device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda(
) else "cpu"
safetensor_files_sub_lists = [
safetensor_list[i:i + world_size]
for i in range(0, len(safetensor_list), world_size)
]
for st_file in tqdm(
safetensor_files_sub_lists,
desc="Loading fastsafetensors checkpoint shards",
):
loader = SafeTensorsFileLoader(pg,
device,
nogds=True,
debug_log=False,
framework="paddle")
rank_file_map = {i: [f] for i, f in enumerate(st_file)}
loader.add_filenames(rank_file_map)
try:
fb = loader.copy_files_to_device()
try:
keys = list(fb.key_to_rank_lidx.keys())
for k in keys:
t = fb.get_tensor(k)
yield k, t
finally:
fb.close()
finally:
loader.close()
def get_state_dict(model_path, config, use_fastsafetensor=False):
"""
get_state_dict
"""
state_dict = {}
_, safetensor_list = get_safetensor_file(
os.path.join(model_path, f"rank{config.tensor_parallel_rank}"))
if use_fastsafetensor:
weights_iterator = fastsafetensors_weights_iterator(safetensor_list)
else:
weights_iterator = safetensors_weights_iterator(safetensor_list)
for name, weight in weights_iterator:
state_dict[name] = weight
return state_dict
def apply_quant(name_action_quant_mappings, key, tensor, state_dict):
"""
apply_quant
"""
if key in name_action_quant_mappings:
action = name_action_quant_mappings.pop(key)
quant_weight_tensor, weight_scale_tensor = action(tensor)
if quant_weight_tensor is not None and weight_scale_tensor is not None:
state_dict[key + ".quant_weight"] = quant_weight_tensor
state_dict[key + ".weight_scale"] = weight_scale_tensor
else:
state_dict[key] = quant_weight_tensor
else:
state_dict[key] = tensor
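The action stored in name_action_quant_mappings is expected to return a (quant_weight, weight_scale) pair; the toy action below only illustrates that contract and is not one of the project's quantization kernels.

import numpy as np

def fake_int8_quant(tensor):
    # Toy per-tensor symmetric quantization, just to show the
    # (quant_weight, weight_scale) tuple apply_quant expects.
    scale = np.abs(tensor).max() / 127.0
    return np.clip(np.round(tensor / scale), -127, 127).astype(np.int8), np.float32(scale)

state_dict = {}
mappings = {"demo.layers.0.mlp.down_proj.weight": fake_int8_quant}  # invented key
apply_quant(mappings, "demo.layers.0.mlp.down_proj.weight",
            np.random.randn(8, 8).astype(np.float32), state_dict)
# state_dict now holds "...weight.quant_weight" and "...weight.weight_scale".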
def load_checkpoint(model_path, cls, config, return_numpy=True, load_gpu=True):
"""
load checkpoint
"""
if getattr(config, "parallel_config", None) is not None:
use_ep = getattr(config.parallel_config, "use_ep", False)
tensor_parallel_degree = config.parallel_config.tensor_parallel_degree
else:
use_ep = getattr(config, "use_ep", False)
tensor_parallel_degree = config.tensor_parallel_degree
if getattr(config, "model_config", None) is not None:
model_config = config.model_config
else:
model_config = config
if use_ep:
state_dict = load_ep_checkpoint(model_path,
config,
return_numpy=True,
return_key_name=True)
else:
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank")
and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if tensor_parallel_degree != len(rank_dirs):
raise ValueError(
f"Checkpoint at {model_path} is sharded into {len(rank_dirs)} rank directories, "
f"so it can only be loaded with tensor_parallel_degree={len(rank_dirs)} "
f"(got {tensor_parallel_degree}).")
state_dict = get_state_dict(model_path, model_config)
else:
state_dict = load_tp_checkpoint(model_path,
cls,
model_config,
return_numpy=return_numpy)
import re
# NOTE: the loop below has no effect; its body only skips (continues) for layers beyond the first.
for k, v in state_dict.items():
match = re.search(r'layers\.(\d+)', k)
if match and int(match.group(1)) > 0:
continue
return state_dict
def parser_quant_type(quant_type):
"""
Parse the quantization type string and return the corresponding quantization types for weights,

View File

@@ -13,10 +13,19 @@
# limitations under the License.
"""fastdeploy gpu ops"""
import os
import sys
from fastdeploy.import_ops import import_custom_ops
PACKAGE = "fastdeploy.model_executor.ops.gpu"
import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
def tolerant_import_error():
"""Replace this module with a stub whose attributes resolve to None, so that
environments missing the custom GPU ops can still import this package."""
class NoneModule:
def __getattr__(self, name):
return None
sys.modules[__name__] = NoneModule()
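The stub-module trick can be reproduced in isolation; the sketch below registers a stand-in module under an invented name whose attributes all resolve to None, which is the behaviour tolerant_import_error relies on.

import sys
import types

class _NoneModule(types.ModuleType):
    # Every missing attribute resolves to None instead of raising AttributeError.
    def __getattr__(self, name):
        return None

sys.modules["fake_gpu_ops"] = _NoneModule("fake_gpu_ops")  # invented module name
import fake_gpu_ops

assert fake_gpu_ops.save_output is None  # missing custom ops degrade to None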

View File

@@ -0,0 +1,354 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import importlib
import inspect
import os
import re
import sys
import paddle
import triton
from .triton_utils import (SubstituteTemplate, build_package, compile_file,
extract_triton_kernel, find_so_path,
get_pointer_hint, link_file, multi_process_do,
python_path, rename_c_to_cu)
def get_value_hint(x):
"""
Build the Triton signature hint string for scalar (non-tensor) arguments:
integers are hinted as i64 and floats as fp32.
"""
hint = ""
for ele in x:
if isinstance(ele, int):
# A plain "i64" hint is always valid; divisibility specializations ("i64:16", "i64:1") could be added here.
hint += "i64,"
elif isinstance(ele, float):
hint += "fp32,"
return hint
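Assuming the simplified int/float handling above, a mixed scalar list maps to a signature fragment like this:

assert get_value_hint([128, 1, 0.5]) == "i64,i64,fp32,"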
common_template = ("""
#include "${op_name}_kernel.h"
#include "paddle/extension.h"
void ${op_name}_func(${tensor_and_attr}) {
auto run_stream = a_ptr->stream();
auto res_flag = ${op_name}_kernel(run_stream, ${triton_kernel_args}, 0);
if (res_flag == CUDA_ERROR_INVALID_VALUE) {
PD_THROW("${op_name}_kernel failed");
}
}
PYBIND11_MODULE(${op_name}_package, m) {
m.def("${op_name}_func", ${op_name}_func, "launch the ${op_name} triton kernel");
}
""")
class KernelInterface:
"""
triton kernel interface.
"""
def __init__(
self,
func,
other_config,
key_args=["1"],
):
"""
triton kernel interface.
"""
self.func = func
self.key_args = key_args
signature = inspect.signature(func)
self.arg_names = [v.name for v in signature.parameters.values()]
for ele in self.arg_names:
assert self.arg_names.count(ele) == 1
# arg_defaults = [v.default for v in signature.parameters.values()]
# self.annotations = {
# name: ty for name, ty in func.__annotations__.items()
# }
self.annotations = dict(func.__annotations__)
self.constexprs = [
self.arg_names.index(name) for name in self.arg_names
if self.annotations.get(name) == triton.language.core.constexpr
]
self.arg_exclude_constexpr = [
self.arg_names[i] for i in range(len(self.arg_names))
if i not in self.constexprs
]
import textwrap
py_script = textwrap.dedent(inspect.getsource(func))
pat = r"def\s" + func.__name__
func_begin = re.findall(pat, py_script)
assert len(func_begin) == 1
func_begin = func_begin[0]
py_script = py_script[py_script.find(func_begin):]
self.func_map = {}
def decorator(*args, **kwargs):
"""
decorator for triton kernels.
Args:
*args: positional arguments
**kwargs: keyword arguments
"""
op_name = "haha" + str(kwargs["N"])
if op_name in self.func_map.keys():
return self.func_map[op_name](*args)
all_input = []
for i in range(len(args)):
all_input.append(args[i])
position_arguments_num = len(all_input)
for i in range(position_arguments_num, len(self.arg_names)):
if self.arg_names[i] in kwargs.keys():
all_input.append(kwargs[self.arg_names[i]])
else:
# this input was not passed explicitly, so it must be a tl.constexpr.
assert i in self.constexprs
all_input.append(None)
dtypes = []
x_list = []
const_args = [self.arg_names[i] for i in self.constexprs]
decalare_arg_exclude_constexpr = list(self.arg_exclude_constexpr)
passed_arg_exclude_constexpr = list(self.arg_exclude_constexpr)
const_hint_dict = {}
for i in range(len(all_input)):
ele = all_input[i]
if type(ele) in [
paddle.Tensor, paddle.base.framework.EagerParamBase,
paddle.base.framework.Parameter,
paddle.base.framework.Variable,
paddle.base.libpaddle.pir.Value,
type(None)
]:
if ele is not None:
dtypes.append(ele.dtype)
passed_arg_exclude_constexpr[
i] = f"(CUdeviceptr)({passed_arg_exclude_constexpr[i]}->data())"
else:
dtypes.append(paddle.int8)
passed_arg_exclude_constexpr[
i] = "(CUdeviceptr)(nullptr)"
decalare_arg_exclude_constexpr[
i] = "const paddle::optional<paddle::Tensor>&" + decalare_arg_exclude_constexpr[
i]
elif i in self.constexprs:
if isinstance(ele, bool):
const_hint_dict[self.arg_names[i]] = (int)(ele)
elif isinstance(ele, int):
if ele < 0:
const_hint_dict[self.arg_names[i]] = 0
else:
const_hint_dict[self.arg_names[i]] = ele
else:
assert False
else:
x_list.append(ele)
if isinstance(ele, int):
decalare_arg_exclude_constexpr[
i] = "const int64_t " + decalare_arg_exclude_constexpr[
i]
elif isinstance(ele, float):
decalare_arg_exclude_constexpr[
i] = "const float " + decalare_arg_exclude_constexpr[
i]
else:
assert False
python_package_name = f"{op_name}_package"
tp_rank = paddle.distributed.get_rank()
generated_dir = os.getenv("TRITON_KERNEL_CACHE_DIR",
f"/tmp/triton_cache/rank{tp_rank}")
print("the kernel cache dir is:", generated_dir)
generated_dir = f"{generated_dir}/{op_name}"
os.makedirs(generated_dir, exist_ok=True)
py_script_file = f"{generated_dir}/triton_kernels.py"
extract_triton_kernel(func, py_script_file)
address_hint = get_pointer_hint(dtypes)
value_hint = get_value_hint(x_list)
const_args = [f"{{{ele}}}" for ele in const_args]
const_args = ",".join(const_args)
lanuch_grid = list(self.grid)
for i in range(len(lanuch_grid)):
ele = lanuch_grid[i]
if isinstance(ele, str):
keys = list(const_hint_dict.keys())
keys.sort(key=len, reverse=True)
for key in keys:
if key in ele:
ele = ele.replace(key, f"{const_hint_dict[key]}")
else:
ele = str(ele)
lanuch_grid[i] = ele
if len(lanuch_grid) < 3:
lanuch_grid += ["1"] * (3 - len(lanuch_grid))
lanuch_grid = ",".join(lanuch_grid)
op_dict = {"op_name": op_name}
op_dict["triton_kernel_args"] = ",".join(
passed_arg_exclude_constexpr)
op_dict["tensor_and_attr"] = ",".join(
decalare_arg_exclude_constexpr)
paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
so_path = find_so_path(generated_dir, python_package_name)
if so_path is None:
print("== we do not find so_path, we need to compile it")
with open(paddle_custom_op_file_path, "w") as f:
f.write(SubstituteTemplate(
common_template,
op_dict,
))
f.close()
# ahead of time compile command.
aot_template = (
f"""{python_path} {compile_file} {py_script_file} """ +
f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
+ f"""--out-name {op_name}_kernel """ +
""" -w {num_warps} -ns {num_stages} """ +
f""" -s"{address_hint} {value_hint} {const_args}" """ +
f""" -g "{lanuch_grid}" """)
all_tune_config = [const_hint_dict]
# reset const_hint_dict as empty.
const_hint_dict = {}
codegen_commands = []
for config in all_tune_config:
for key in const_hint_dict.keys():
if const_hint_dict[key] is not None:
if key not in config.keys():
config[key] = const_hint_dict[key]
else:
if config[key] == const_hint_dict[key]:
pass
else:
message = (
f"{key} is specified both as a kernel argument and in the config, "
"but the two values differ."
)
raise ValueError(message)
else:
assert key in config.keys(
), f"you must specify {key} in your config."
if "num_warps" not in config.keys():
config["num_warps"] = 4
if "num_stages" not in config.keys():
config["num_stages"] = 4
for key in config:
assert config[
key] is not None, f"{key} must be specified."
codegen_command = aot_template.format(**config, )
print(codegen_command)
codegen_commands.append(codegen_command)
multi_process_do(codegen_commands)
link_command = (
f"{python_path} {link_file} "
f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel")
link_ret = os.system(link_command)  # do not reuse the name `re`, which shadows the regex module
assert link_ret == 0, f"linking the generated kernels failed: {link_command}"
# rename the .c file to .cu
rename_c_to_cu(generated_dir)
# build the package to so, not install
build_package(generated_dir, python_package_name)
# so_path have be found!
so_path = find_so_path(generated_dir, python_package_name)
print("== we find so_path: ", so_path)
assert so_path is not None
dir_path = os.path.dirname(so_path)
sys.path.append(dir_path)
lib = importlib.import_module(python_package_name)
pybind_func = getattr(lib, f"{op_name}_func")
self.func_map[op_name] = pybind_func
# run this op!
self.func_map[op_name](*args)
self.decorator = decorator
def __getitem__(self, op_name_and_grid):
"""
override the operator [], which will call the decorator function.
Args:
op_name_and_grid: the name of the operator and the grid size.
Returns:
the decorator function.
"""
self.grid = ((
"((max_possible_num_post_padded + BLOCK_SIZE_M -1)/ BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N-1) / BLOCK_SIZE_N)"
), )
return self.decorator
def paddle_use_triton_v2(other_config={}, key=[]):
"""
The decorator function that wraps the original function.
Args:
func: the original function.
Returns:
the wrapped function.
"""
def decorator(func):
"""
The decorator function that wraps the original function.
Args:
func: the original function.
Returns:
the wrapped function.
"""
return KernelInterface(func, other_config, key)
return decorator

View File

@@ -17,6 +17,7 @@ from typing import Dict, Optional
import paddle
from fastdeploy import envs
from fastdeploy.engine.config import SpeculativeConfig
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset, save_output, set_stop_value_multi_ends,
@@ -24,10 +25,11 @@ from fastdeploy.model_executor.ops.gpu import (
speculate_get_padding_offset, speculate_get_seq_lens_output,
speculate_save_output, speculate_set_value_by_flags_and_idx,
speculate_step_paddle, speculate_step_system_cache, speculate_update_v3,
step_paddle, step_system_cache, update_inputs)
step_paddle, step_system_cache, update_inputs, step_reschedule)
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import ModelOutputData
DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
def pre_process(
max_len: int,
@@ -214,6 +216,8 @@ def step_cuda(
"""
TODO(gongshaotian): normalization name
"""
if speculative_config.method is not None:
if enable_prefix_caching:
speculate_step_system_cache(
@@ -291,6 +295,33 @@ def step_cuda(
share_inputs["input_ids"], share_inputs["pre_ids"],
share_inputs["step_idx"], share_inputs["next_tokens"],
share_inputs["first_token_ids"], block_size, enc_dec_block_num)
elif DISABLE_RECOVER:
step_reschedule(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)
else:
step_paddle(
share_inputs["stop_flags"],

Some files were not shown because too many files have changed in this diff.