mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-29 22:02:30 +08:00
Compare commits
31 Commits
v2.0.0
...
release/2.
Author | SHA1 | Date | |
---|---|---|---|
![]() |
3ec126dc02 | ||
![]() |
337d76f094 | ||
![]() |
ae2f78184d | ||
![]() |
6851489425 | ||
![]() |
ea787d8f62 | ||
![]() |
90ef28d982 | ||
![]() |
b37585e693 | ||
![]() |
9cb08e71e8 | ||
![]() |
dacc46f04c | ||
![]() |
09ded7715f | ||
![]() |
11cfdf5d89 | ||
![]() |
e7fa57ebae | ||
![]() |
a5ae88ded9 | ||
![]() |
87e638498c | ||
![]() |
667547be59 | ||
![]() |
b38823bc66 | ||
![]() |
050d9658a5 | ||
![]() |
be5cabaf80 | ||
![]() |
240bdac2a4 | ||
![]() |
00863c43fd | ||
![]() |
3d3bccdf79 | ||
![]() |
9fd74f75bd | ||
![]() |
05c670e593 | ||
![]() |
d222248d00 | ||
![]() |
e5b94d4117 | ||
![]() |
87e2e58a22 | ||
![]() |
de20e5a992 | ||
![]() |
2f9c0618f0 | ||
![]() |
9a14ab6572 | ||
![]() |
d1cb3ed571 | ||
![]() |
b8a8a19689 |
83
.github/workflows/ci_xpu.yml
vendored
Normal file
83
.github/workflows/ci_xpu.yml
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
name: CI_XPU
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ develop ]
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.pull_request.number }}-xpu-ci
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: [self-hosted, XPU-P800-8Card]
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
# Because the system version is lower than 2.23, the checkout cannot be used.
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge pr/${{ github.event.pull_request.number }}
|
||||
git log -n 3 --oneline
|
||||
else
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [[ "$last_char" =~ [0-3] ]]; then
|
||||
gpu_id="$last_char"
|
||||
else
|
||||
gpu_id="0"
|
||||
fi
|
||||
FD_API_PORT=$((9180 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((9170 + gpu_id * 100))
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "/ssd3:/ssd3" \
|
||||
-e "MODEL_PATH=/ssd3/model" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci_xpu.sh
|
||||
"
|
6
.github/workflows/gh-pages.yml
vendored
6
.github/workflows/gh-pages.yml
vendored
@@ -3,8 +3,6 @@ name: Deploy GitHub Pages
|
||||
on:
|
||||
push:
|
||||
branches: [ develop ]
|
||||
pull_request:
|
||||
branches: [ develop ]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -21,4 +19,6 @@ jobs:
|
||||
- name: Deploy to GitHub Pages
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: mkdocs gh-deploy --force --remote-name origin
|
||||
run: |
|
||||
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git
|
||||
mkdocs gh-deploy --force --remote-name origin
|
||||
|
@@ -5,12 +5,6 @@ default_stages:
|
||||
- pre-commit # Run locally
|
||||
# - manual # Run in CI
|
||||
repos:
|
||||
# 格式化
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.43.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
# 代码检查
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.7
|
||||
@@ -29,15 +23,6 @@ repos:
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
# # 格式化
|
||||
# - repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
# rev: v20.1.3
|
||||
# hooks:
|
||||
# - id: clang-format
|
||||
# # exclude: '.*'
|
||||
# types_or: [c++, cuda]
|
||||
# args: [--style=file, --verbose]
|
||||
|
||||
# markdown
|
||||
- repo: https://github.com/jackdewinter/pymarkdown
|
||||
rev: v0.9.29
|
||||
|
1180
benchmarks/quick_benchmark.py
Normal file
1180
benchmarks/quick_benchmark.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,3 +3,4 @@ tqdm
|
||||
numpy
|
||||
Pillow
|
||||
pyyaml
|
||||
requests
|
||||
|
3
benchmarks/yaml/request_yaml/quick_benchmark.yaml
Normal file
3
benchmarks/yaml/request_yaml/quick_benchmark.yaml
Normal file
@@ -0,0 +1,3 @@
|
||||
metadata:
|
||||
min_tokens: 32
|
||||
max_tokens: 33
|
19
build.sh
19
build.sh
@@ -166,7 +166,7 @@ function build_and_install() {
|
||||
|
||||
echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
|
||||
cd $DIST_DIR
|
||||
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install
|
||||
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir
|
||||
if [ $? -ne 0 ]; then
|
||||
cd ..
|
||||
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
|
||||
@@ -176,6 +176,21 @@ function build_and_install() {
|
||||
cd ..
|
||||
}
|
||||
|
||||
function version_info() {
|
||||
output_file="fastdeploy/version.txt"
|
||||
fastdeploy_git_commit_id=$(git rev-parse HEAD)
|
||||
paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
|
||||
paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
|
||||
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
|
||||
cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
|
||||
|
||||
echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
|
||||
echo "Paddle version: $paddle_version" >> $output_file
|
||||
echo "Paddle GIT COMMIT ID: $paddle_git_commit_id" >> $output_file
|
||||
echo "CUDA version: $cuda_version" >> $output_file
|
||||
echo "CXX compiler version: $cxx_version" >> $output_file
|
||||
}
|
||||
|
||||
function cleanup() {
|
||||
rm -rf $BUILD_DIR $EGG_DIR
|
||||
if [ `${python} -m pip list | grep fastdeploy | wc -l` -gt 0 ]; then
|
||||
@@ -207,6 +222,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
|
||||
set -e
|
||||
|
||||
init
|
||||
version_info
|
||||
build_and_install_ops
|
||||
build_and_install
|
||||
cleanup
|
||||
@@ -237,6 +253,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
|
||||
else
|
||||
init
|
||||
build_and_install_ops
|
||||
version_info
|
||||
rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR
|
||||
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
|
||||
fi
|
||||
|
@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
|
||||
from . import jit
|
||||
from .jit_kernels import (
|
||||
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
|
||||
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
from . import interleave_ffma
|
||||
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
|
||||
index fcb377e..db9d6f3 100644
|
||||
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
|
||||
import subprocess
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
|
||||
|
||||
|
||||
|
||||
def run_cuobjdump(file_path):
|
||||
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
|
||||
index 66c370a..4761426 100644
|
||||
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from .template import map_ctype
|
||||
@@ -35,7 +35,7 @@ class Runtime:
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict, Iterable, Tuple
|
||||
|
||||
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
|
||||
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
|
||||
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
|
||||
}
|
||||
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -27,25 +27,25 @@ genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
|
||||
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
if hasattr(value, 'data_ptr'):
|
||||
- if value.dtype == torch.int:
|
||||
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
|
||||
+import paddle
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||
|
||||
|
||||
|
||||
|
||||
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor) -> None:
|
||||
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
- this function will do a transposing with a set of slow PyTorch operations.
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
|
||||
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
|
||||
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
|
||||
- assert n % 64 == 0 and k % 128 == 0
|
||||
+ # assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
- assert n > 0 and k > 0
|
||||
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
|
||||
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
+ # assert out.dtype == paddle.bfloat16
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
|
||||
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
|
||||
+ this function will do a transposing with a set of slow Pypaddle operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
m__ = m_indices.numel()
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ == m__ and k == k_ and n == n_
|
||||
- assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
|
||||
+ # assert m_indices.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
|
||||
)
|
||||
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
runtime(*args)
|
||||
|
||||
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
|
||||
+ # assert masked_m.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
- torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
@@ -51,10 +51,10 @@ class JITTuner:
|
||||
continue
|
||||
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
|
||||
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
|
||||
"""
|
||||
global _num_sms
|
||||
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
|
||||
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
return ceil_div(x, alignment) * alignment
|
||||
|
||||
|
||||
|
||||
|
||||
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
|
||||
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
||||
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
||||
|
||||
|
||||
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
|
||||
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
|
||||
# Normal layout requires transposing
|
||||
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
+ aligned_x = paddle.transpose(
|
||||
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
|
||||
-import torch.distributed as dist
|
||||
+import paddle
|
||||
+import paddle.distributed as dist
|
||||
|
||||
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
high_precision: bool = False):
|
||||
# Flush L2 cache with 256 MB data
|
||||
- torch.cuda.synchronize()
|
||||
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
+ paddle.device.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
|
||||
cache.zero_()
|
||||
|
||||
|
||||
# Warmup
|
||||
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
|
||||
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
x @ y
|
||||
|
||||
|
||||
# Testing
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
|
||||
end_event.record()
|
||||
- torch.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
|
||||
|
||||
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
|
||||
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
|
||||
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
--
|
||||
2.43.0
|
||||
|
||||
if not using_nsys:
|
||||
--
|
||||
2.43.0
|
||||
|
236
custom_ops/gpu_ops/append_attn/decode_attention_func.cuh
Normal file
236
custom_ops/gpu_ops/append_attn/decode_attention_func.cuh
Normal file
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "multi_head_latent_attention_kernel.h"
|
||||
|
||||
template <size_t vec_size, typename T>
|
||||
struct softmax_state_t {
|
||||
AlignedVector<T, vec_size> o;
|
||||
T m;
|
||||
T d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = __float2half(-5e4f);
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = __float2bfloat16(-3.38953e38f);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ softmax_state_t() {
|
||||
init();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
|
||||
T other_m,
|
||||
T other_d) {
|
||||
// using kType = typename cascade_attn_nv_type2_traits<T>::type;
|
||||
T m_prev = m, d_prev = d;
|
||||
m = m_prev > other_m ? m_prev : other_m;
|
||||
T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);
|
||||
|
||||
d = d_prev * scale1 + other_d * scale2;
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] = o[i] * scale1 + other_o[i] * scale2;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize() {
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
|
||||
struct softmax_state_ts {
|
||||
uint32_t num_tiles_ = num_tiles;
|
||||
AlignedVector<T, vec_size> o[num_tiles];
|
||||
float m;
|
||||
float d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
#pragma unroll
|
||||
for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ softmax_state_ts() {
|
||||
init();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize(const uint32_t tile_id) {
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; i++) {
|
||||
o[tile_id][i] /= d;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
|
||||
__device__ __forceinline__ void produce_kv(CacheT *smem,
|
||||
CacheT *kv_base_gptr,
|
||||
const int * block_table_smem,
|
||||
const uint32_t seq_offset_gmem,
|
||||
const uint32_t seq_offset_smem,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t tidx,
|
||||
const uint32_t chunk_start,
|
||||
const uint32_t chunk_end) {
|
||||
int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
|
||||
if (block_id < 0) {
|
||||
block_id = 0;
|
||||
}
|
||||
const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
|
||||
// 8/16 T/int8 each time
|
||||
const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
|
||||
const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
|
||||
smem + smem_offset_base + vid * CACHE_VEC_SIZE,
|
||||
kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
|
||||
seq_offset_gmem < chunk_end
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
|
||||
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
|
||||
const CacheT* k_smem,
|
||||
const uint32_t kv_idx_base,
|
||||
const uint32_t stage_idx,
|
||||
const uint32_t iter_base,
|
||||
const uint32_t iter_bound,
|
||||
const uint32_t tidx,
|
||||
const uint32_t gid,
|
||||
const float scale,
|
||||
float *s,
|
||||
softmax_state_ts<vec_size, T, num_tile_v>& st) {
|
||||
const CacheT* smem;
|
||||
AlignedVector<T, vec_size> q_vec;
|
||||
AlignedVector<T, vec_size> k_vec;
|
||||
float m_prev = st.m;
|
||||
// smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
|
||||
smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
|
||||
if (iter_base + j < iter_bound) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s[j] = 0.f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
s[j] = 0.f;
|
||||
}
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
|
||||
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
|
||||
s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s[j] = -5e4f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
s[j] = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
st.m = st.m > s[j] ? st.m : s[j];
|
||||
}
|
||||
|
||||
// T o_scale = hexp(m_prev - st.m);
|
||||
float o_scale = __expf(m_prev - st.m);
|
||||
st.d *= o_scale;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
|
||||
// s[j] = hexp(s[j] - st.m);
|
||||
s[j] = __expf(s[j] - st.m);
|
||||
st.d += s[j];
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
st.o[tile_id][i] *= o_scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
|
||||
__device__ __forceinline__ void compute_sv(const float *s,
|
||||
const CacheT *base_v_smem,
|
||||
const uint32_t stage_idx,
|
||||
const uint32_t iter_base,
|
||||
const uint32_t iter_bound,
|
||||
const uint32_t tidx,
|
||||
softmax_state_ts<vec_size, T, num_tile>& st) {
|
||||
const CacheT* v_smem;
|
||||
AlignedVector<T, vec_size> v_vec;
|
||||
#pragma unroll
|
||||
for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
|
||||
v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
|
||||
uint32_t tile_id = vid / bdx;
|
||||
#pragma unroll
|
||||
for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
|
||||
st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
560
custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu
Normal file
560
custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu
Normal file
@@ -0,0 +1,560 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decode_attention_func.cuh"
|
||||
|
||||
#define CHECK(call) \
|
||||
do \
|
||||
{ \
|
||||
const cudaError_t error_code = call; \
|
||||
if (error_code != cudaSuccess) \
|
||||
{ \
|
||||
printf("CUDA Error:\n"); \
|
||||
printf(" File: %s\n", __FILE__); \
|
||||
printf(" Line %d:\n", __LINE__); \
|
||||
printf(" Error code:%d\n", error_code); \
|
||||
printf(" Error text:%s\n", cudaGetErrorString(error_code)); \
|
||||
exit(1); \
|
||||
} \
|
||||
}while(0)
|
||||
|
||||
template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim]
|
||||
const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads]
|
||||
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cum_offsets,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const float in_scale,
|
||||
const int num_chunks,
|
||||
const int chunk_size,
|
||||
const int max_seq_len,
|
||||
const int num_heads,
|
||||
const int head_dim) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int qid = blockIdx.x, hid = blockIdx.y;
|
||||
const int seq_len_q = seq_lens_q[qid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[qid];
|
||||
if (seq_len_kv == 0) return;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
|
||||
return;
|
||||
}
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ T md_smem[bdy * 2];
|
||||
|
||||
const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]);
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
T m;
|
||||
T d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = __float2half(-5e4f);
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = __float2bfloat16(-3.38953e38f);
|
||||
}
|
||||
// merge per ty
|
||||
#pragma unroll 2
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
|
||||
T m_prev = m;
|
||||
T d_prev = d;
|
||||
const T m_now = multi_m[offset];
|
||||
const T d_now = multi_d[offset];
|
||||
m = m_prev > m_now ? m_prev : m_now;
|
||||
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma once
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
|
||||
// merge bdy
|
||||
softmax_state_t<vec_size, T> st{};
|
||||
const uint32_t iter_num = min(num_chunks_this_seq, bdy);
|
||||
#pragma once
|
||||
for (int i = 0; i < iter_num; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
|
||||
AlignedVector<OutT, vec_size> out_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size; ++i) {
|
||||
out_vec[i] = static_cast<OutT>(st.o[i]);
|
||||
}
|
||||
Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
|
||||
template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
|
||||
uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
|
||||
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim]
|
||||
CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||
CacheT * __restrict__ cache_v,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cum_offsets,
|
||||
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
const float scale,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim]
|
||||
T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads]
|
||||
T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads]
|
||||
OutT * __restrict__ out) {
|
||||
const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const uint32_t bid = bidx, gid = threadIdx.y;
|
||||
const uint32_t tidx = threadIdx.x;
|
||||
constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
|
||||
constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
|
||||
constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;
|
||||
|
||||
const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
|
||||
const uint32_t kv_num_heads = gridDim.z;
|
||||
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
|
||||
|
||||
const int *block_table_now = block_table + bid * max_block_num_per_seq;
|
||||
|
||||
const uint32_t num_chunks = gridDim.y;
|
||||
const uint32_t chunk_id = blockIdx.y;
|
||||
const uint32_t q_len = seq_lens_q[bid];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
uint32_t kv_len = seq_lens_kv[bid]; // !!!!!!!!
|
||||
if (kv_len <= 0) {
|
||||
return;
|
||||
}
|
||||
kv_len += q_len;
|
||||
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
|
||||
const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (chunk_id >= num_chunk_this_seq) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
|
||||
const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
|
||||
const uint32_t chunk_len = chunk_end - chunk_start;
|
||||
|
||||
extern __shared__ uint8_t smem[];
|
||||
const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
|
||||
T *q_smem = reinterpret_cast<T*>(smem); // [HEAD_DIM_QK * sizeof(T)]
|
||||
T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
using VecT = AlignedVector<T, VEC_SIZE>;
|
||||
VecT q_vec;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
|
||||
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
|
||||
q_vec[i] *= scale;
|
||||
}
|
||||
Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
|
||||
}
|
||||
|
||||
|
||||
CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
|
||||
uint32_t stage_idx = 0;
|
||||
constexpr int loop_times = DEAL_EACH_TIME / bdy;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_STAGES; ++i) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
|
||||
const uint32_t k_seq_id = chunk_start + k_seq_offset;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
k_seq_id,
|
||||
k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
|
||||
|
||||
softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
|
||||
float s[DEAL_EACH_TIME];
|
||||
|
||||
const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
|
||||
for (int iter = 0; iter < num_iters; ++iter) {
|
||||
wait_group<NUM_STAGES - 1>();
|
||||
__syncthreads();
|
||||
// compute qk
|
||||
compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
|
||||
cu_q_smem,
|
||||
kv_smem,
|
||||
chunk_start + iter * DEAL_EACH_TIME,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
gid,
|
||||
scale,
|
||||
s,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
// compute sv
|
||||
compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
|
||||
s,
|
||||
kv_smem,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = j * bdy + gid;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
|
||||
stage_idx * DEAL_EACH_TIME + k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// normize if not partition_kv
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
|
||||
const uint32_t tile_id = vid / bdx;
|
||||
if (!partition_kv || num_chunk_this_seq == 1) {
|
||||
st.normalize(tile_id);
|
||||
}
|
||||
if (partition_kv && num_chunk_this_seq > 1) {
|
||||
const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
|
||||
Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
tmp_m[head_idx] = st.m;
|
||||
tmp_d[head_idx] = st.d;
|
||||
} else {
|
||||
Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
|
||||
void MultiQueryDecoderAttention(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
cudaStream_t &stream,
|
||||
const paddle::Tensor &q,
|
||||
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
|
||||
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float rope_scale,
|
||||
const float rope_theta,
|
||||
const float softmax_scale,
|
||||
const float in_scale,
|
||||
paddle::Tensor *out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
constexpr int num_stages = NUM_STAGE;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T); // 8 16 32
|
||||
constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32
|
||||
constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
|
||||
constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
|
||||
constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;
|
||||
|
||||
constexpr int blocky = GROUP_SIZE;
|
||||
const int gridx = bsz;
|
||||
|
||||
constexpr int num_threads = blockx * blocky;
|
||||
|
||||
auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
|
||||
BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
|
||||
uint32_t cache_smem_bytes = 0;
|
||||
|
||||
const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
|
||||
const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
|
||||
cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
const int dev_id = 0;
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
|
||||
assert(act_blocks_per_sm > 1);
|
||||
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
const int num_blocks_need = gridx * num_chunks * kv_num_heads;
|
||||
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
|
||||
const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);
|
||||
|
||||
dim3 grids(gridx, num_chunks, kv_num_heads);
|
||||
dim3 blocks(blockx, blocky);
|
||||
if (num_chunks <= 1) {
|
||||
auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
|
||||
cache_vec_size, blockx, blocky>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
} else {
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
tmp_workspace = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
|
||||
tmp_m = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
tmp_d = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
|
||||
splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
|
||||
constexpr int mblockx = HEAD_DIM_V / vec_size;
|
||||
constexpr int bdy = 256 / mblockx;
|
||||
dim3 grids_merge(bsz, num_heads);
|
||||
dim3 blocks_merge(mblockx, bdy);
|
||||
merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
|
||||
in_scale,
|
||||
num_chunks,
|
||||
chunk_size,
|
||||
max_seq_len,
|
||||
num_heads,
|
||||
HEAD_DIM_V
|
||||
);
|
||||
}
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim_qk = meta_data.head_dims;
|
||||
const auto head_dim_v = meta_data.head_dims_v;
|
||||
const float rope_scale = 0.0;
|
||||
const float rope_theta = 0.0;
|
||||
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
|
||||
const uint32_t num_stage = get_cascade_attention_num_stages();
|
||||
const uint32_t num_threads = get_cascade_attention_num_threads();
|
||||
|
||||
DISPATCH_CAUSAL(causal, CAUSAL,
|
||||
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
|
||||
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
|
||||
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
|
||||
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
|
||||
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
|
||||
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
|
||||
}
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
291
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Normal file
291
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Normal file
@@ -0,0 +1,291 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "mla_cache_kernel.cuh"
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto num_tokens = meta_data.token_nums;
|
||||
auto block_size = meta_data.block_size;
|
||||
auto nope_size = meta_data.head_dims_v;
|
||||
auto all_size = meta_data.head_dims;
|
||||
int pe_size = all_size - nope_size;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;
|
||||
|
||||
constexpr int PackSize = 16 / sizeof(DataType_);
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
prefill_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
const auto& kv_nope_dims = kv_nope.dims();
|
||||
const auto& kv_pe_dims = kv_pe.dims();
|
||||
const auto& kv_cache_dims = kv_cache.dims();
|
||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto block_size = meta_data.block_size;
|
||||
auto nope_size = meta_data.head_dims_v;
|
||||
auto all_size = meta_data.head_dims;
|
||||
int pe_size = all_size - nope_size;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
constexpr int PackSize = 16 / sizeof(DataType_);
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
|
||||
|
||||
if (speculate_decoder) {
|
||||
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
} else {
|
||||
const uint32_t elem_nums = bsz * kv_num_heads * all_size;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
decode_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
const auto& kv_nope_dims = kv_nope.dims();
|
||||
const auto& kv_pe_dims = kv_pe.dims();
|
||||
const auto& kv_cache_dims = kv_cache.dims();
|
||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(prefill_mla_write_cache)
|
||||
.Inputs({"kv_nope",
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"seq_lens",
|
||||
"seq_lens_decoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int"})
|
||||
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
|
||||
|
||||
PD_BUILD_OP(decode_mla_write_cache)
|
||||
.Inputs({"kv_nope",
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"seq_lens",
|
||||
"seq_lens_encoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int",
|
||||
"speculate_decoder: bool"})
|
||||
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
|
242
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh
Normal file
242
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh
Normal file
@@ -0,0 +1,242 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "mem_util.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void decode_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / hidden_size;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void speculate_decode_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens[ori_bi],
|
||||
token_id,
|
||||
cum_offsets[ori_bi]);
|
||||
}
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void prefill_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const uint32_t token_idx = linear_index / hidden_size;
|
||||
const uint32_t bias = linear_index % hidden_size;
|
||||
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const uint32_t ori_bi = ori_token_idx / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const uint32_t ori_seq_id =
|
||||
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
|
||||
const uint32_t block_offset = ori_seq_id % block_size;
|
||||
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "helper.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
@@ -25,6 +25,7 @@ struct AppendAttnMetaData {
|
||||
int kv_num_heads;
|
||||
int token_nums;
|
||||
int head_dims;
|
||||
int head_dims_v;
|
||||
int max_blocks_per_seq;
|
||||
};
|
||||
|
||||
@@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
|
||||
if (num_stage == 2) { \
|
||||
constexpr size_t NUM_STAGE = 2; \
|
||||
__VA_ARGS__ \
|
||||
#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 192: { \
|
||||
constexpr size_t HEAD_DIM = 192; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("not support the head_dim: ", head_dim); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 192: { \
|
||||
constexpr size_t HEAD_DIM = 192; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 512: { \
|
||||
constexpr size_t HEAD_DIM = 512; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 576: { \
|
||||
constexpr size_t HEAD_DIM = 576; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("not support the head_dim: ", head_dim); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
|
||||
if (num_stage == 2) { \
|
||||
constexpr size_t NUM_STAGE = 2; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the num_stage: ", num_stage); \
|
||||
}
|
||||
|
||||
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
|
||||
@@ -328,10 +375,13 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
|
||||
constexpr size_t cache_bytes = 4; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the cache_type: ", cache_type); \
|
||||
}
|
||||
|
||||
|
||||
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
|
||||
if (deal_each_time == 32) { \
|
||||
if (deal_each_time == 32) { \
|
||||
constexpr size_t DEAL_EACH_TIME = 32; \
|
||||
__VA_ARGS__ \
|
||||
} else if (deal_each_time == 64) { \
|
||||
@@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
PD_THROW("not support the group_size", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 128) { \
|
||||
constexpr size_t GROUP_SIZE = 128; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
|
||||
if (block_shape_q <= 16) { \
|
||||
constexpr size_t BLOCK_SHAPE_Q = 16; \
|
||||
|
@@ -316,6 +316,96 @@ void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
|
||||
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
|
||||
int64_t num_experts);
|
||||
void GetPositionIdsAndMaskEncoderBatch(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& position_ids,
|
||||
const paddle::Tensor& mask_encoder_batch);
|
||||
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len);
|
||||
|
||||
|
||||
void FusedRotaryPositionEncoding(
|
||||
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
paddle::Tensor& key,
|
||||
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
|
||||
// head_size]
|
||||
const paddle::Tensor& position_ids, // [num_tokens]
|
||||
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
int head_size,
|
||||
bool is_neox);
|
||||
|
||||
std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& query_bias,
|
||||
const paddle::optional<paddle::Tensor>& query_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder);
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
|
||||
@@ -370,6 +460,14 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
|
||||
paddle::Tensor const &input,
|
||||
paddle::Tensor &scales, float scale_ub);
|
||||
|
||||
std::vector<paddle::Tensor> NoauxTc(
|
||||
paddle::Tensor& scores,
|
||||
paddle::Tensor& scores_with_bias,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
float routed_scaling_factor);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
|
||||
@@ -627,6 +725,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("use_atomic_add"),
|
||||
py::arg("use_fp32_reduce"),
|
||||
py::arg("is_zp_float"));
|
||||
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
|
||||
"get_position_ids_and_mask_encoder_batch function");
|
||||
|
||||
|
||||
/**
|
||||
@@ -653,4 +753,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
|
||||
"dynamic_per_token_scaled_fp8_quant function",
|
||||
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
|
||||
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
|
||||
|
||||
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
|
||||
|
||||
m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function");
|
||||
|
||||
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
|
||||
|
||||
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
|
||||
}
|
||||
|
64
custom_ops/gpu_ops/env.h
Normal file
64
custom_ops/gpu_ops/env.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
inline uint32_t get_decoder_block_shape_q() {
|
||||
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
|
||||
static const uint32_t decoder_block_shape_q =
|
||||
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
|
||||
return decoder_block_shape_q;
|
||||
}
|
||||
|
||||
inline uint32_t get_encoder_block_shape_q() {
|
||||
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
|
||||
static const uint32_t encoder_block_shape_q =
|
||||
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
|
||||
return encoder_block_shape_q;
|
||||
}
|
||||
|
||||
inline uint32_t get_max_partition_size(int bsz) {
|
||||
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
|
||||
static const uint32_t max_partition_size =
|
||||
max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(max_partition_size_env));
|
||||
return max_partition_size;
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_deal_each_time() {
|
||||
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
|
||||
static const uint32_t cascade_attention_deal_each_time =
|
||||
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
|
||||
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_num_stages() {
|
||||
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
|
||||
static const uint32_t cascade_attention_num_stages =
|
||||
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
|
||||
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_num_threads() {
|
||||
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
|
||||
static const uint32_t cascade_attention_num_threads =
|
||||
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
|
||||
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
|
||||
}
|
||||
|
||||
inline bool get_mla_use_tensorcore() {
|
||||
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
|
||||
static const uint32_t mla_use_tensorcore =
|
||||
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
|
||||
return mla_use_tensorcore != 0 ? true : false;
|
||||
}
|
146
custom_ops/gpu_ops/fused_rotary_position_encoding.cu
Normal file
146
custom_ops/gpu_ops/fused_rotary_position_encoding.cu
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename T, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding_kernel(
|
||||
T* __restrict__ arr,
|
||||
const T* __restrict__ cos_ptr,
|
||||
const T* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim) {
|
||||
int x_index, y_index;
|
||||
T cos, sin;
|
||||
if (IS_NEOX) {
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = cos_ptr[x_index];
|
||||
sin = sin_ptr[x_index];
|
||||
} else {
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = cos_ptr[x_index / 2];
|
||||
sin = sin_ptr[x_index / 2];
|
||||
}
|
||||
|
||||
const T x = arr[x_index];
|
||||
const T y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
|
||||
template <typename T, bool IS_NEOX>
|
||||
__global__ void apply_rotary_embedding_kernel(
|
||||
T* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
T* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||
const int* __restrict__ position_ids, // [num_tokens]
|
||||
const T* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int pos = position_ids[token_idx];
|
||||
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const T* cos_ptr = cache_ptr;
|
||||
const T* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void FusedRotaryPositionEncoding(
|
||||
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
paddle::Tensor& key,
|
||||
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
|
||||
// head_size]
|
||||
const paddle::Tensor& position_ids, // [num_tokens]
|
||||
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
int head_size,
|
||||
bool is_neox) {
|
||||
int64_t num_tokens = query.dims()[0];
|
||||
int num_heads = query.numel() / num_tokens / head_size;
|
||||
int num_kv_heads = key.numel() / num_tokens / head_size;
|
||||
int rot_dim = cos_sin_cache.dims()[1];
|
||||
int64_t query_stride = num_heads * head_size;
|
||||
int64_t key_stride = num_kv_heads * head_size;
|
||||
|
||||
if (num_tokens > 65535) {
|
||||
PD_THROW(
|
||||
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
|
||||
}
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
|
||||
query.dtype(), "apply_rotary_embedding_kernel", [&] {
|
||||
if (is_neox) {
|
||||
apply_rotary_embedding_kernel<data_t, true>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
key.data<data_t>(),
|
||||
position_ids.data<int>(),
|
||||
cos_sin_cache.data<data_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
apply_rotary_embedding_kernel<data_t, false>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
key.data<data_t>(),
|
||||
position_ids.data<int>(),
|
||||
cos_sin_cache.data<data_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
PD_BUILD_OP(fused_rotary_position_encoding)
|
||||
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
|
||||
.Outputs({"query_out", "key_out"})
|
||||
.Attrs({"head_size: int", "is_neox: bool"})
|
||||
.SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
|
||||
.SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
|
@@ -0,0 +1,86 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
|
||||
const int* seq_lens_encoder, // [bsz] 每个批次的 encoder 长度
|
||||
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
|
||||
const int* seq_lens_this_time,
|
||||
int* position_ids, // 输出的一维 position_ids
|
||||
int* mask_encoder_batch,
|
||||
const int bsz) { // 批次大小
|
||||
// 当前线程索引(每个线程对应一个批次)
|
||||
int tid = threadIdx.x;
|
||||
if (tid >= bsz) return;
|
||||
|
||||
// 动态计算当前批次的偏移量
|
||||
int offset = 0;
|
||||
for (int i = 0; i < tid; i++) {
|
||||
offset += seq_lens_encoder[i];
|
||||
if (seq_lens_decoder[i] > 0) {
|
||||
offset += seq_lens_this_time[i];
|
||||
}
|
||||
}
|
||||
|
||||
// 当前批次的 encoder 和 decoder 长度
|
||||
int encoder_len = seq_lens_encoder[tid];
|
||||
int decoder_len = seq_lens_decoder[tid];
|
||||
int seq_len_this_time = seq_lens_this_time[tid];
|
||||
|
||||
// 写入 encoder 的 position_ids
|
||||
for (int i = 0; i < encoder_len; i++) {
|
||||
position_ids[offset + i] = i;
|
||||
mask_encoder_batch[offset + i] = 1;
|
||||
}
|
||||
offset += encoder_len;
|
||||
|
||||
// 写入 decoder 的 position_ids
|
||||
if (decoder_len > 0) {
|
||||
for (int i = 0; i < seq_len_this_time; i++) {
|
||||
position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身
|
||||
mask_encoder_batch[offset + i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void GetPositionIdsAndMaskEncoderBatch(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& position_ids,
|
||||
const paddle::Tensor& mask_encoder_batch) {
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
const_cast<int*>(position_ids.data<int>()),
|
||||
const_cast<int*>(mask_encoder_batch.data<int>()),
|
||||
bsz);
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_position_ids_and_mask_encoder_batch)
|
||||
.Inputs({"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"position_ids",
|
||||
"mask_encoder_batch"})
|
||||
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
|
||||
.SetInplaceMap({{"position_ids", "position_ids_out"},
|
||||
{"mask_encoder_batch", "mask_encoder_batch_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
|
@@ -39,10 +39,12 @@ namespace cub = hipcub;
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "env.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#include "paddle/phi/core/cuda_stream.h"
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
#include "paddle/phi/backends/gpu/gpu_info.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
@@ -513,3 +515,10 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
inline int GetSMVersion() {
|
||||
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
|
||||
phi::backends::gpu::GetCurrentDeviceId());
|
||||
return sm_version;
|
||||
|
||||
}
|
||||
|
255
custom_ops/gpu_ops/mla_attn/attention_updater.cuh
Normal file
255
custom_ops/gpu_ops/mla_attn/attention_updater.cuh
Normal file
@@ -0,0 +1,255 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/detail/helper_macros.hpp>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
struct MaxOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
|
||||
};
|
||||
|
||||
template <int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template <typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator& op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Allreduce<2> {
|
||||
template <typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator& op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& summary, Operator& op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0>& dst,
|
||||
Tensor<Engine1, Layout1>& src, Operator& op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++) {
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& summary, Operator& op) {
|
||||
thread_reduce_<init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& max) {
|
||||
MaxOp<float> max_op;
|
||||
reduce_<init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template <bool init, bool warp_reduce = true, typename Engine0, typename Layout0, typename Engine1,
|
||||
typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& sum) {
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) {
|
||||
quad_allreduce_(sum, sum, sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void apply_exp2(Tensor<Engine0, Layout0>& tensor,
|
||||
Tensor<Engine1, Layout1> const& max) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
auto row_max = max(mi);
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = __expf(tensor(mi, ni) - row_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor,
|
||||
Tensor<Engine1, Layout1> const& max,
|
||||
const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
auto row_max = max(mi);
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// row_max * scale is a constant for each row, so we can use fma here
|
||||
tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE>
|
||||
struct OnlineSoftmax {
|
||||
constexpr static float fill_value = -5e4;
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
|
||||
TensorT row_max, row_sum, scores_scale;
|
||||
float sm_scale_log2;
|
||||
|
||||
CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) {
|
||||
clear(scores_scale);
|
||||
};
|
||||
|
||||
__forceinline__ __device__ TensorT get_lse() const { return row_sum; }
|
||||
|
||||
template <bool init, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT update(Tensor0& acc_s) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
|
||||
if constexpr (init) {
|
||||
reduce_max</*init=*/true>(scores, row_max);
|
||||
if constexpr (WITH_SCALE) {
|
||||
scale_apply_exp2(scores, row_max, sm_scale_log2);
|
||||
} else {
|
||||
apply_exp2(scores, row_max);
|
||||
}
|
||||
reduce_sum</*init=*/true, /*warp_reduce=*/false>(scores, row_sum);
|
||||
} else {
|
||||
// update row_max
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
reduce_max</*init=*/false>(scores, row_max);
|
||||
// update scores_scale and scale row_sum
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = row_max(mi);
|
||||
if constexpr (WITH_SCALE) {
|
||||
scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
|
||||
} else {
|
||||
scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur);
|
||||
}
|
||||
row_sum(mi) *= scores_scale(mi);
|
||||
}
|
||||
// perform exp2 on scores
|
||||
if constexpr (WITH_SCALE) {
|
||||
scale_apply_exp2(scores, row_max, sm_scale_log2);
|
||||
} else {
|
||||
apply_exp2(scores, row_max);
|
||||
}
|
||||
// update row_sum
|
||||
reduce_sum</*init=*/false, /*warp_reduce=*/false>(scores, row_sum);
|
||||
return scores_scale;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tensor0>
|
||||
__forceinline__ __device__ TensorT finalize(Tensor0& acc_s) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = 1.f / sum;
|
||||
scores_scale(mi) = inv_sum;
|
||||
row_max(mi) *= sm_scale_log2;
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template <typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tensor1, typename Tensor2>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) {
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale_input(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
235
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
Normal file
235
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
Normal file
@@ -0,0 +1,235 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "cute/tensor.hpp"
|
||||
#include "mla_hopper.cuh"
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "batch_mla_with_paged_kv_cache.h"
|
||||
#include "env.h"
|
||||
|
||||
using namespace cute;
|
||||
using namespace mla_attn;
|
||||
using namespace std;
|
||||
|
||||
template <typename T>
|
||||
struct cascade_type_traits {
|
||||
using type = T;
|
||||
using cutlass_type = T;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::bfloat16> {
|
||||
using type = __nv_bfloat16;
|
||||
using cutlass_type = cutlass::bfloat16_t;;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::float16> {
|
||||
using type = half;
|
||||
using cutlass_type = cutlass::half_t;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::float8_e4m3fn> {
|
||||
using type = __nv_fp8_e4m3;
|
||||
using cutlass_type = cutlass::float_e4m3_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void BatchMLAWithPagedKVCacheKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
using NV_TYPE = typename cascade_type_traits<T>::type;
|
||||
using CUTLASS_TYPE = typename cascade_type_traits<T>::cutlass_type;
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
const auto q_head_num = meta_data.q_num_heads;
|
||||
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
const auto max_block_num = bsz * max_block_num_per_seq;
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
|
||||
|
||||
int q_head_dim = meta_data.head_dims;
|
||||
int k_head_dim = meta_data.head_dims;
|
||||
int v_head_dim = meta_data.head_dims_v;
|
||||
// int num_chunks = max_dec_len / chunk_size;
|
||||
int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
|
||||
O_tmp = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num * v_head_dim));
|
||||
m_tmp = allocator->Allocate(
|
||||
sizeof(float) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
|
||||
d_tmp = allocator->Allocate(
|
||||
sizeof(float) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
|
||||
|
||||
Params<CUTLASS_TYPE, CUTLASS_TYPE, CUTLASS_TYPE, int> params = {};
|
||||
params.Q = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(q.data<T>()));
|
||||
params.KV = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(latent_cache.data<T>()));
|
||||
params.O = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(out->data<T>()));
|
||||
params.O_tmp = reinterpret_cast<CUTLASS_TYPE*>(O_tmp->ptr());
|
||||
params.m = reinterpret_cast<float*>(m_tmp->ptr());
|
||||
params.d = reinterpret_cast<float*>(d_tmp->ptr());
|
||||
params.block_tables = const_cast<int*>(block_tables.data<int>());
|
||||
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
|
||||
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
|
||||
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
|
||||
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
|
||||
params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
|
||||
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
|
||||
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
|
||||
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
|
||||
params.num_blocks_x_int = num_blocks_x;
|
||||
params.q_stride_bsz = q_head_num * q_head_dim;
|
||||
params.q_stride_head_num = q_head_dim;
|
||||
params.kv_stride_block_num = block_size * k_head_dim;
|
||||
params.kv_stride_block_size = k_head_dim;
|
||||
params.o_stride_bsz = q_head_num * v_head_dim;
|
||||
params.o_stride_head_num = v_head_dim;
|
||||
params.bsz = bsz;
|
||||
params.token_num = token_num;
|
||||
params.max_seq_len = max_seq_len;
|
||||
params.max_block_num = max_block_num;
|
||||
params.max_block_num_per_seq = max_block_num_per_seq;
|
||||
params.q_num_head = q_head_num;
|
||||
params.qk_head_dim = q_head_dim;
|
||||
params.vo_head_dim = v_head_dim;
|
||||
params.block_size = block_size;
|
||||
params.max_draft_token_num = draft_token_num;
|
||||
params.sm_scale = softmax_scale;
|
||||
params.chunk_size = chunk_size;
|
||||
params.chunk_num = num_chunks;
|
||||
|
||||
if (q_head_dim == 576) {
|
||||
BatchMLAWithPagedKVCacheDispatched<576, 512, NV_TYPE>(
|
||||
params, stream
|
||||
);
|
||||
} else {
|
||||
PD_THROW("error!!! q_head_dim must be 576 !!!\n");
|
||||
}
|
||||
}
|
||||
|
||||
template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
|
||||
template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
69
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h
Normal file
69
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2023 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#include "append_attn/utils.cuh"
|
||||
|
||||
template <typename T>
|
||||
void BatchMLAWithPagedKVCacheKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
175
custom_ops/gpu_ops/mla_attn/epilogue.cuh
Normal file
175
custom_ops/gpu_ops/mla_attn/epilogue.cuh
Normal file
@@ -0,0 +1,175 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
|
||||
#ifndef ATTENTION_HOPPER_EPILOGUE_CUH_
|
||||
#define ATTENTION_HOPPER_EPILOGUE_CUH_
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "named_barrier.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Ktraits>
|
||||
struct CollectiveEpilogue {
|
||||
using DTypeO = typename Ktraits::DTypeO;
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
|
||||
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
|
||||
|
||||
static constexpr int NUM_WARPS = Ktraits::NUM_WARPS;
|
||||
static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;
|
||||
|
||||
static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
|
||||
|
||||
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
|
||||
using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;
|
||||
|
||||
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
|
||||
using StrideT = cute::Shape<int32_t, _1, int32_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
using ShapeTmpT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
|
||||
using StrideTmpT = cute::Shape<int32_t, _1, int32_t, int32_t>;
|
||||
using LayoutTmpT = cute::Layout<ShapeTmpT, StrideTmpT>;
|
||||
|
||||
using ShapeNTMAT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideNTMAT = cute::Shape<int32_t, _1>;
|
||||
using LayoutNTMAT = cute::Layout<ShapeNTMAT, StrideNTMAT>;
|
||||
|
||||
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
|
||||
using TMA_O = decltype(make_tma_copy(
|
||||
GmemTiledCopyOTMA{},
|
||||
make_tensor(make_gmem_ptr(static_cast<DTypeO*>(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{},
|
||||
select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O
|
||||
|
||||
static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v<DTypeO>); // 8
|
||||
static_assert(HEAD_DIM_VO % VEC_SIZE == 0);
|
||||
static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64
|
||||
static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0);
|
||||
static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW;
|
||||
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, DTypeO>;
|
||||
using TiledCopyOThrLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(Int<NUM_ROWS>{}, Int<NUM_THREADS_PER_ROW>{}), LayoutRight{}));
|
||||
using TiledCopyOValLayout =
|
||||
decltype(cute::make_layout(cute::make_shape(_1{}, Int<VEC_SIZE>{}), LayoutRight{}));
|
||||
using TiledCopyO =
|
||||
decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout
|
||||
TiledCopyOValLayout{} // Val layout
|
||||
));
|
||||
struct Arguments {
|
||||
DTypeO* O_ptr;
|
||||
LayoutNTMAT const layout_O;
|
||||
DTypeO* O_ptr_tmp;
|
||||
LayoutNTMAT const layout_O_tmp;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
DTypeO* O_ptr;
|
||||
LayoutNTMAT const layout_O;
|
||||
DTypeO* O_ptr_tmp;
|
||||
LayoutNTMAT const layout_O_tmp;
|
||||
};
|
||||
|
||||
static Params to_underlying_arguments_ntma(Arguments const& args) {
|
||||
return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& epilogue_params) {}
|
||||
|
||||
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE,
|
||||
typename TiledMma>
|
||||
CUTLASS_DEVICE void store(Params const& epilogue_params,
|
||||
FrgTensorO const& tOrO,
|
||||
FrgTensorLSE const& lse,
|
||||
SharedStorage& shared_storage,
|
||||
TiledMma tiled_mma,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int bsz,
|
||||
const int seq_len_now,
|
||||
const int start_token_idx,
|
||||
const int tile_idx,
|
||||
const int kv_len,
|
||||
const int chunk_size,
|
||||
const int max_draft_token_num,
|
||||
const int o_stride_bsz) {
|
||||
const int num_chunks = cute::ceil_div(kv_len, chunk_size);
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOrO_out = convert_type<DTypeO>(tOrO);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
// make sure gemm done
|
||||
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
|
||||
// r2s
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
// make sure r2s done
|
||||
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
|
||||
TiledCopyO gmem_tiled_copy_O;
|
||||
auto O_ptr = num_chunks == 1 ? epilogue_params.O_ptr + start_token_idx * o_stride_bsz : epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
|
||||
Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O);
|
||||
Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{});
|
||||
Tensor cO = make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx)
|
||||
ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx);
|
||||
Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D))
|
||||
Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D))
|
||||
Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D))
|
||||
|
||||
// copy if not out of bound
|
||||
auto predicate_fn = [&](auto coords) {
|
||||
auto s_coords = tOcOGroup(_0{}, coords);
|
||||
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now);
|
||||
};
|
||||
copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_EPILOGUE_CUH_
|
163
custom_ops/gpu_ops/mla_attn/kernel_traits.cuh
Normal file
163
custom_ops/gpu_ops/mla_attn/kernel_traits.cuh
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
|
||||
#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename MainloopPipeline, typename MainloopPipelineQ, class DTypeQ, class DTypeKV, class DTypeQKAccum, class DTypeOut, class IdType,
|
||||
int BLOCK_SHAPE_KV, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutP, class SmemLayoutRow, class SmemLayoutO>
|
||||
struct alignas(16) SharedStorageQKVO {
|
||||
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutP>> smem_p;
|
||||
alignas(16) cute::array_aligned<DTypeQKAccum, cute::cosize_v<SmemLayoutRow>> smem_scale;
|
||||
union {
|
||||
alignas(16) cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutK>> smem_kv;
|
||||
alignas(16) cute::array_aligned<DTypeOut, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
struct {
|
||||
alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q;
|
||||
alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv;
|
||||
};
|
||||
};
|
||||
|
||||
template <bool USE_TMA_LOAD_KV_, int HEAD_DIM_QK_, int HEAD_DIM_VO_, int GROUP_SIZE_, int BLOCK_SHAPE_Q_, int BLOCK_SHAPE_KV_,
|
||||
int NUM_STAGES_, typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_, typename NV_TYPE_>
|
||||
struct AttentionKernelTraits {
|
||||
|
||||
using DTypeQ = DTypeQ_;
|
||||
using DTypeKV = DTypeKV_;
|
||||
using DTypeO = DTypeO_;
|
||||
using IdType = IdType_;
|
||||
using DTypeQKAccum = float;
|
||||
using DTypePVAccum = float;
|
||||
using NV_TYPE = NV_TYPE_;
|
||||
|
||||
|
||||
static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_;
|
||||
static constexpr int GROUP_SIZE = GROUP_SIZE_;
|
||||
static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_;
|
||||
static_assert(BLOCK_SHAPE_Q % 64 == 0);
|
||||
static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_;
|
||||
static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
|
||||
static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
|
||||
static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK;
|
||||
static_assert(HEAD_DIM_QK % 32 == 0);
|
||||
static_assert(HEAD_DIM_VO % 32 == 0);
|
||||
|
||||
static constexpr int NUM_WARPS = 12;
|
||||
static constexpr int NUM_THREADS = 384;
|
||||
static constexpr int NUM_PRODUCER_THREADS = 128;
|
||||
|
||||
using TileShape_QKD = Shape<Int<BLOCK_SHAPE_Q>, Int<BLOCK_SHAPE_KV>, Int<HEAD_DIM_QK>>;
|
||||
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
|
||||
|
||||
static constexpr int NUM_STAGES = NUM_STAGES_;
|
||||
|
||||
using AtomLayoutQKD = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _1, _1>>;
|
||||
using AtomLayoutPV = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _2, _1>>;
|
||||
using TiledMmaQK = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<DTypeQ, DTypeKV, DTypeQKAccum, TileShape_QKD>(), AtomLayoutQKD{}));
|
||||
using TiledMmaPV = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
AtomLayoutPV{}));
|
||||
using TiledMmaPVSS = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
AtomLayoutPV{}));
|
||||
|
||||
static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{});
|
||||
|
||||
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
|
||||
decltype(cute::get<2>(TileShape_QKD{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
|
||||
decltype(cute::get<2>(TileShape_QKD{}))>());
|
||||
using SmemLayoutK = decltype(tile_to_shape(
|
||||
SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int<NUM_STAGES>{})));
|
||||
using SmemLayoutVt = decltype(composition(
|
||||
SmemLayoutK{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}),
|
||||
get<1>(TileShape_QKD{}), Int<NUM_STAGES>{}),
|
||||
Step<_2, _1, _3>{})));
|
||||
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomV{},
|
||||
make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{})));
|
||||
|
||||
// Note this is the transpose in terms of the view, not in terms of memory.
|
||||
using SmemLayoutVtOneStage = decltype(composition(
|
||||
SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}),
|
||||
get<2>(TileShape_PDV{}), Int<1>{}),
|
||||
Step<_2, _1, _3>{})));
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
|
||||
|
||||
using SmemCopyAtom = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeQ>;
|
||||
|
||||
static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32);
|
||||
using SmemLayoutRowOneStage = Layout<Shape<_2, Int<128>>, Stride<_1, _2>>;
|
||||
using SmemLayoutRowTwoStage = Layout<Shape<_2, Int<128>, _2>, Stride<_1, _2, _256>>;
|
||||
using SmemLayoutRow = std::conditional_t<IS_CTA_32, SmemLayoutRowTwoStage, SmemLayoutRowOneStage>;
|
||||
|
||||
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
|
||||
decltype(cute::get<1>(TileShape_QKD{}))>());
|
||||
using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{})));
|
||||
using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int<BLOCK_SHAPE_Q>{}, Int<BLOCK_SHAPE_KV>{}, Int<2>{})));
|
||||
using SmemLayoutP = std::conditional_t<IS_CTA_32, SmemLayoutPSSTwoStage, SmemLayoutPSSOneStage>;
|
||||
|
||||
using MainloopPipelineQ = typename cutlass::PipelineAsync<1>;
|
||||
using PipelineStateQ = typename cutlass::PipelineState<1>;
|
||||
using MainloopPipeline =
|
||||
std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
|
||||
typename cutlass::PipelineAsync<NUM_STAGES>>;
|
||||
using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;
|
||||
|
||||
using SharedStorage = SharedStorageQKVO<MainloopPipeline, MainloopPipelineQ, DTypeQ, DTypeKV, DTypeQKAccum, DTypeO, IdType, BLOCK_SHAPE_KV,
|
||||
SmemLayoutQ, SmemLayoutK, SmemLayoutP, SmemLayoutRow, SmemLayoutO>;
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif
|
348
custom_ops/gpu_ops/mla_attn/mainloop_load.cuh
Normal file
348
custom_ops/gpu_ops/mla_attn/mainloop_load.cuh
Normal file
@@ -0,0 +1,348 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
|
||||
#define ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "named_barrier.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Ktraits, bool CAUSAL>
|
||||
struct CollectiveMainloop {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = float;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
using TileShape_PDV = typename Ktraits::TileShape_PDV;
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
|
||||
static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
|
||||
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
|
||||
|
||||
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(DTypeQ); // 8
|
||||
static_assert(HEAD_DIM_QK % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // 576 512
|
||||
static constexpr int kGmemThreadsPerRow = 64 / kGmemElemsPerLoad; // 8
|
||||
using AlignmentTypeQ = cute::uint_byte_t<static_cast<int>(sizeof(DTypeQ)) * kGmemElemsPerLoad>;
|
||||
using GmemCopyAtomQ = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<AlignmentTypeQ>, DTypeQ>;
|
||||
static constexpr int kNThreadsLoad = Ktraits::NUM_PRODUCER_THREADS;
|
||||
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // 32, 8
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopy = decltype(make_tiled_copy(
|
||||
GmemCopyAtomQ{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using GmemLayoutAtomQ = Layout<
|
||||
Shape<Int<Ktraits::NUM_PRODUCER_THREADS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // 32, 8
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopyQ = decltype(make_tiled_copy(
|
||||
GmemCopyAtomQ{},
|
||||
GmemLayoutAtomQ{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutAtomQ = typename Ktraits::SmemLayoutAtomQ;
|
||||
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
|
||||
using ShapeQT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideQT = cute::Shape<int32_t, _1>;
|
||||
using LayoutQT = cute::Layout<ShapeQT, StrideQT>;
|
||||
|
||||
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
|
||||
using StrideT = cute::Shape<int32_t, _1, int32_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
using ShapeMDT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideMDT = cute::Shape<int32_t, _1>;
|
||||
using LayoutMDT = cute::Layout<ShapeMDT, StrideMDT>;
|
||||
|
||||
using TMA_KV = decltype(make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<DTypeKV const*>(nullptr)),
|
||||
repeat_like(StrideT{}, int32_t(0)), StrideT{}
|
||||
),
|
||||
take<0, 2>(SmemLayoutK{}),
|
||||
select<1, 2>(TileShape_QKD{}),
|
||||
_1{})); // no mcast for KV
|
||||
|
||||
static constexpr bool USE_TMA_LOAD_KV = Ktraits::USE_TMA_LOAD_KV;
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ;
|
||||
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||
|
||||
static constexpr uint32_t TmaTransactionBytesQ =
|
||||
static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);
|
||||
static constexpr uint32_t TmaTransactionBytesKV =
|
||||
static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<DTypeKV> / 8);
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
LayoutQT layout_Q;
|
||||
LayoutT layout_KV;
|
||||
LayoutMDT layout_MD;
|
||||
DTypeQ const* Q_ptr;
|
||||
DTypeKV const* KV_ptr;
|
||||
DTypeMD const* m_ptr;
|
||||
DTypeMD const* d_ptr;
|
||||
IdType const* kv_block_tables;
|
||||
IdType const* seq_lens_this_time;
|
||||
IdType const* seq_lens_encoder;
|
||||
IdType const* seq_lens_decoder;
|
||||
IdType const* cumsum_q_seqlens;
|
||||
IdType const* batch_ids;
|
||||
IdType const* tile_ids_per_batch;
|
||||
IdType const* num_blocks_x;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_stride_bsz;
|
||||
int q_stride_head_num;
|
||||
int kv_stride_block_num;
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
LayoutQT layout_Q;
|
||||
LayoutT layout_KV;
|
||||
LayoutMDT layout_MD;
|
||||
DTypeQ *Q_ptr;
|
||||
DTypeKV* KV_ptr;
|
||||
DTypeMD* m_ptr;
|
||||
DTypeMD* d_ptr;
|
||||
IdType* kv_block_tables;
|
||||
IdType* seq_lens_this_time;
|
||||
IdType* seq_lens_encoder;
|
||||
IdType* seq_lens_decoder;
|
||||
IdType* cumsum_q_seqlens;
|
||||
IdType* batch_ids;
|
||||
IdType* tile_ids_per_batch;
|
||||
IdType* num_blocks_x;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_stride_bsz;
|
||||
int q_stride_head_num;
|
||||
int kv_stride_block_num;
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
TMA_KV tma_load_KV;
|
||||
};
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args) {
|
||||
TMA_KV tma_load_KV;
|
||||
if constexpr (USE_TMA_LOAD_KV) {
|
||||
Tensor mKV = make_tensor(make_gmem_ptr(args.KV_ptr), args.layout_KV);
|
||||
tma_load_KV =
|
||||
make_tma_copy(GmemTiledCopyKV{}, mKV, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{});
|
||||
}
|
||||
return {args.layout_Q,
|
||||
args.layout_KV,
|
||||
args.layout_MD,
|
||||
const_cast<DTypeQ*>(args.Q_ptr),
|
||||
const_cast<DTypeKV*>(args.KV_ptr),
|
||||
const_cast<DTypeMD*>(args.m_ptr),
|
||||
const_cast<DTypeMD*>(args.d_ptr),
|
||||
const_cast<IdType*>(args.kv_block_tables),
|
||||
const_cast<IdType*>(args.seq_lens_this_time),
|
||||
const_cast<IdType*>(args.seq_lens_encoder),
|
||||
const_cast<IdType*>(args.seq_lens_decoder),
|
||||
const_cast<IdType*>(args.cumsum_q_seqlens),
|
||||
const_cast<IdType*>(args.batch_ids),
|
||||
const_cast<IdType*>(args.tile_ids_per_batch),
|
||||
const_cast<IdType*>(args.num_blocks_x),
|
||||
args.sm_scale,
|
||||
args.bsz,
|
||||
args.max_block_num,
|
||||
args.max_block_num_per_seq,
|
||||
args.q_stride_bsz,
|
||||
args.q_stride_head_num,
|
||||
args.kv_stride_block_num,
|
||||
args.kv_stride_block_size,
|
||||
args.o_stride_bsz,
|
||||
args.o_stride_head_num,
|
||||
args.chunk_size,
|
||||
args.chunk_num,
|
||||
args.max_draft_token_num,
|
||||
tma_load_KV
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
if constexpr (USE_TMA_LOAD_KV) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_KV.get_tma_descriptor());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_q(Params const& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_write_q,
|
||||
SharedStorage& shared_storage,
|
||||
const int thread_idx,
|
||||
const int bid) {
|
||||
int start_q_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
int offset_Q = mainloop_params.q_stride_bsz * start_q_token_idx;
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(mainloop_params.Q_ptr + offset_Q), mainloop_params.layout_Q);
|
||||
Tensor gQ =
|
||||
local_tile(mQ, select<0, 2>(TileShape_QKD{}), make_coord(_, _0{}))(_, _, _0{});
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor cQ = cute::make_identity_tensor(gQ.shape());
|
||||
|
||||
GmemTiledCopyQ gmem_tiled_copy_q;
|
||||
auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(thread_idx);
|
||||
Tensor tQgQ = gmem_thr_copy_q.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_q.partition_D(sQ);
|
||||
Tensor tQcQ = gmem_thr_copy_q.partition_D(cQ);
|
||||
Tensor tQcQGroup = flatten_1(tQcQ);
|
||||
|
||||
int valid_q_size = mainloop_params.seq_lens_this_time[bid];
|
||||
auto q_predicate_fn = [&](auto coords) {
|
||||
auto s_coords = tQcQGroup(_0{}, coords);
|
||||
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, valid_q_size);
|
||||
};
|
||||
Tensor tQgQiGroup = flatten_1(tQgQ);
|
||||
Tensor tQsQiGroup = flatten_1(tQsQ);
|
||||
|
||||
pipeline_q.producer_acquire(smem_pipe_write_q);
|
||||
copy_if(gmem_tiled_copy_q, q_predicate_fn, tQgQiGroup, tQsQiGroup);
|
||||
pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
|
||||
++smem_pipe_write_q;
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_kv(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_write_kv,
|
||||
SharedStorage& shared_storage,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int tile_idx) {
|
||||
int thread_idx = threadIdx.x;
|
||||
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor mKV = make_tensor(make_gmem_ptr(mainloop_params.KV_ptr), mainloop_params.layout_KV);
|
||||
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
|
||||
GmemTiledCopy gmem_tiled_copy_kv;
|
||||
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
Tensor tKgK = gmem_thr_copy_kv.partition_S(gKV);
|
||||
Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);
|
||||
|
||||
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
const int block_idx = kv_block_tables(bid, kv_tile_idx);
|
||||
pipeline_kv.producer_acquire(smem_pipe_write_kv);
|
||||
Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, block_idx));
|
||||
Tensor tKsKiGroup =
|
||||
flatten_1(tKsK(_, _, _, smem_pipe_write_kv.index()));
|
||||
copy(gmem_tiled_copy_kv, tKgKiGroup, tKsKiGroup);
|
||||
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
|
||||
++smem_pipe_write_kv;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_kv_tma(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_write_kv,
|
||||
SharedStorage& shared_storage,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int tile_idx) {
|
||||
int thread_idx = threadIdx.x;
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
|
||||
Tensor mKV = mainloop_params.tma_load_KV.get_tma_tensor(mainloop_params.layout_KV.shape());
|
||||
|
||||
// Prepare the TMA loads
|
||||
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
|
||||
auto [tKgK, tKsK] =
|
||||
tma_partition(mainloop_params.tma_load_KV, _0{}, Layout<_1>{},
|
||||
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
#pragma unroll 2
|
||||
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
const int block_idx = kv_block_tables(bid, kv_tile_idx);
|
||||
pipeline_kv.producer_acquire(smem_pipe_write_kv);
|
||||
copy(mainloop_params.tma_load_KV.with(*pipeline_kv.producer_get_barrier(smem_pipe_write_kv), /*mcast_mask=*/0),
|
||||
tKgK(_, block_idx), tKsK(_, smem_pipe_write_kv.index()));
|
||||
++smem_pipe_write_kv;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_
|
500
custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh
Normal file
500
custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh
Normal file
@@ -0,0 +1,500 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
|
||||
#define ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include "named_barrier.cuh"
|
||||
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
|
||||
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
|
||||
CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_read_q,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_read_kv,
|
||||
FrgTensorO& tOrO,
|
||||
AttentionUpdater& attention_updater,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int qo_len,
|
||||
const int tile_idx,
|
||||
SharedStorage& shared_storage) {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = typename Ktraits::DTypeO;
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutP = typename Ktraits::SmemLayoutP;
|
||||
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
|
||||
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
|
||||
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
|
||||
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); // (bsz * draft_token_num * num_head)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
|
||||
typename Ktraits::TiledMmaQK tiled_mma_qk;
|
||||
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
|
||||
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
|
||||
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
|
||||
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup);
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
|
||||
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
|
||||
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 1) {
|
||||
// consumer 0, compute qk
|
||||
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
|
||||
|
||||
constexpr int n_masking_steps = !CAUSAL ? 1 : cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) + 1;
|
||||
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
|
||||
bool is_first_step = true;
|
||||
// wait q
|
||||
consumer_wait(pipeline_q, smem_pipe_read_q);
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
#pragma unroll 1
|
||||
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
|
||||
// wait kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
// gemm qk
|
||||
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
// mask
|
||||
if (masking_step > 0) {
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update s (exp(s - m))
|
||||
Tensor scale_o = is_first_step ? attention_updater.update</*init=*/true>(tSrS) : attention_updater.update</*init=*/false>(tSrS);
|
||||
is_first_step = false;
|
||||
|
||||
Tensor convert_tSrS = convert_type<DTypeKV>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_tSrS);
|
||||
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP);
|
||||
cute::copy(scale_o, tScalesScale);
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
// make sure r2s all done
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
|
||||
// pv gemm
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
// sync WG1 WG2
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
|
||||
}
|
||||
// release q
|
||||
pipeline_q.consumer_release(smem_pipe_read_q);
|
||||
++smem_pipe_read_q;
|
||||
|
||||
// normalize
|
||||
Tensor scale_o = attention_updater.finalize(tSrS); // warp reduce row sum
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cute::copy(scale_o, tScalesScale);
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
|
||||
// WG1 write m,d back to gmem
|
||||
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, eg. t0->row0 row8,t4->row1 row9
|
||||
const int warp_idx = thread_idx / 32;
|
||||
#pragma unroll
|
||||
for (int w_i = 0; w_i < 2; ++w_i) {
|
||||
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
|
||||
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
|
||||
|
||||
if (token_idx < qo_len) {
|
||||
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
|
||||
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
|
||||
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
|
||||
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
|
||||
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_group_idx == 2) {
|
||||
// consumer 1, compute pv
|
||||
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
|
||||
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
// wait kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
// A: tPsP
|
||||
cute::copy(tScalesScale, scale_o);
|
||||
|
||||
// rescale
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
// sync WG1 WG2
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
|
||||
}
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
|
||||
cute::copy(tScalesScale, scale_o);
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
|
||||
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
|
||||
CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_read_q,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_read_kv,
|
||||
FrgTensorO& tOrO,
|
||||
AttentionUpdater& attention_updater,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int qo_len,
|
||||
const int tile_idx,
|
||||
SharedStorage& shared_storage) {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = typename Ktraits::DTypeO; // !!! bf16
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutP = typename Ktraits::SmemLayoutP;
|
||||
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
|
||||
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s3 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 2 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s4 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 3 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
|
||||
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
|
||||
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
|
||||
|
||||
typename Ktraits::TiledMmaQK tiled_mma_qk;
|
||||
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
|
||||
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
|
||||
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
|
||||
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup, _);
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
|
||||
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
|
||||
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrV3 = threadMmaPVSS.partition_fragment_B(sVt_s3);
|
||||
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 1) {
|
||||
// consumer 0, compute qk
|
||||
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
|
||||
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
|
||||
// wait q
|
||||
consumer_wait(pipeline_q, smem_pipe_read_q);
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
// wait k
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
// first qk gemm
|
||||
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
// mask
|
||||
{
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor scale_o = attention_updater.update</*init=*/true>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) : 0;
|
||||
--kv_tile_idx;
|
||||
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
PipelineState smem_pipe_read_kv_cur = smem_pipe_read_kv;
|
||||
++smem_pipe_read_kv;
|
||||
// wait next kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
|
||||
// gemm next qk
|
||||
gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
attention_updater.rescale_o(tOrO);
|
||||
// last pv gemm
|
||||
if (smem_pipe_read_kv_cur.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv_cur.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv_cur.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
// wait cur qk gemm
|
||||
warpgroup_wait<1>();
|
||||
// mask p
|
||||
if (masking_step > 0) {
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// update s (exp(s - m))
|
||||
Tensor scale_o = attention_updater.update</*init=*/false>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
|
||||
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
// make sure tSrS r2s done
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
// wait last pv gemm
|
||||
warpgroup_wait<0>();
|
||||
// release last kv
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv_cur);
|
||||
}
|
||||
// release q
|
||||
pipeline_q.consumer_release(smem_pipe_read_q);
|
||||
++smem_pipe_read_q;
|
||||
// compute last pv
|
||||
attention_updater.rescale_o(tOrO);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
scale_o = attention_updater.finalize(tSrS);
|
||||
warpgroup_wait<0>();
|
||||
// release last kv
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
|
||||
attention_updater.rescale_o(tOrO);
|
||||
}
|
||||
// WG1 write m,d back to gmem
|
||||
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, eg. t0->row0 row8,t4->row1 row9
|
||||
const int warp_idx = thread_idx / 32;
|
||||
#pragma unroll
|
||||
for (int w_i = 0; w_i < 2; ++w_i) {
|
||||
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
|
||||
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
|
||||
|
||||
if (token_idx < qo_len) {
|
||||
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
|
||||
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
|
||||
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
|
||||
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
|
||||
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_group_idx == 2) {
|
||||
// consumer 1, compute pv
|
||||
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
|
||||
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
// A: tPsP
|
||||
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
|
||||
// rescale
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
}
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
|
||||
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
|
575
custom_ops/gpu_ops/mla_attn/mla_hopper.cuh
Normal file
575
custom_ops/gpu_ops/mla_attn/mla_hopper.cuh
Normal file
@@ -0,0 +1,575 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
||||
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "attention_updater.cuh"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "epilogue.cuh"
|
||||
#include "helper.h"
|
||||
#include "kernel_traits.cuh"
|
||||
#include "mainloop_mma.cuh"
|
||||
#include "mainloop_load.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
|
||||
struct Params {
|
||||
using DTypeQ = DTypeQ_;
|
||||
using DTypeKV = DTypeKV_;
|
||||
using DTypeO = DTypeO_;
|
||||
using IdType = IdType_;
|
||||
|
||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
|
||||
alignas(16) IdType *block_tables;
|
||||
alignas(16) IdType *seq_lens_this_time;
|
||||
alignas(16) IdType *seq_lens_encoder;
|
||||
alignas(16) IdType *seq_lens_decoder;
|
||||
alignas(16) IdType *cumsum_q_seqlens;
|
||||
alignas(16) IdType *padding_offsets;
|
||||
|
||||
alignas(16) IdType *batch_ids;
|
||||
alignas(16) IdType *tile_ids_per_batch;
|
||||
alignas(16) IdType *num_blocks_x;
|
||||
|
||||
|
||||
uint32_t q_stride_bsz;
|
||||
uint32_t q_stride_head_num;
|
||||
|
||||
uint32_t kv_stride_block_num;
|
||||
uint32_t kv_stride_block_size;
|
||||
|
||||
uint32_t o_stride_bsz;
|
||||
uint32_t o_stride_head_num;
|
||||
|
||||
int bsz;
|
||||
int token_num;
|
||||
int max_seq_len;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_num_head;
|
||||
int qk_head_dim;
|
||||
int vo_head_dim;
|
||||
int block_size;
|
||||
int max_draft_token_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int num_blocks_x_int;
|
||||
|
||||
float sm_scale;
|
||||
};
|
||||
|
||||
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 64) { \
|
||||
constexpr size_t GROUP_SIZE = 64; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
return cudaErrorNotSupported; \
|
||||
}
|
||||
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT
|
||||
typename CollectiveEpilogue::Params const epilogue_params) {
|
||||
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeO = typename Ktraits::DTypeO;
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
using TileShape_PDV = typename Ktraits::TileShape_PDV;
|
||||
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
const int num_blocks_x = mainloop_params.num_blocks_x[0];
|
||||
|
||||
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
|
||||
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
|
||||
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||
|
||||
int const lane_predicate = cute::elect_one_sync();
|
||||
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
|
||||
if (warp_idx == 0 && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
|
||||
}
|
||||
|
||||
// Obtain warp index
|
||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = NUM_MMA_THREADS;
|
||||
} else {
|
||||
pipeline_params.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
|
||||
}
|
||||
|
||||
PipelineParamsQ pipeline_params_q;
|
||||
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
|
||||
: MainloopPipelineQ::ThreadCategory::Consumer;
|
||||
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk
|
||||
|
||||
|
||||
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
|
||||
MainloopPipeline pipeline_kv = [&] {
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
|
||||
/*cluster_shape=*/Shape<_1, _1, _1>{});
|
||||
} else {
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
|
||||
}
|
||||
}();
|
||||
__syncthreads();
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue;
|
||||
|
||||
if (warp_group_idx == 0) {
|
||||
// producer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<72>();
|
||||
}
|
||||
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
|
||||
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// consumer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_alloc<216>();
|
||||
}
|
||||
PipelineStateQ smem_pipe_read_q;
|
||||
PipelineState smem_pipe_read_kv;
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
|
||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaStream_t stream) {
|
||||
using DTypeQ = typename KernelTraits::DTypeQ;
|
||||
using DTypeKV = typename KernelTraits::DTypeKV;
|
||||
using DTypeO = typename KernelTraits::DTypeO;
|
||||
using IdType = typename KernelTraits::IdType;
|
||||
using NV_TYPE = typename KernelTraits::NV_TYPE;
|
||||
|
||||
using CollectiveMainloop =
|
||||
CollectiveMainloop<KernelTraits, CAUSAL>;
|
||||
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
|
||||
|
||||
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
|
||||
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
|
||||
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
|
||||
params.Q,
|
||||
params.KV,
|
||||
params.m,
|
||||
params.d,
|
||||
params.block_tables,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_encoder,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_ids,
|
||||
params.tile_ids_per_batch,
|
||||
params.num_blocks_x,
|
||||
params.sm_scale,
|
||||
params.bsz,
|
||||
params.max_block_num,
|
||||
params.max_block_num_per_seq,
|
||||
params.q_stride_bsz,
|
||||
params.q_stride_head_num,
|
||||
params.kv_stride_block_num,
|
||||
params.kv_stride_block_size,
|
||||
params.o_stride_bsz,
|
||||
params.o_stride_head_num,
|
||||
params.chunk_size,
|
||||
params.chunk_num,
|
||||
params.max_draft_token_num
|
||||
});
|
||||
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
|
||||
params.O,
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
|
||||
params.O_tmp,
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
|
||||
});
|
||||
|
||||
// Get the ptr to kernel function.
|
||||
auto kernel =
|
||||
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
|
||||
int smem_size = sizeof(typename KernelTraits::SharedStorage);
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiprocessor_count;
|
||||
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
|
||||
int act_blocks_per_sm;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
||||
|
||||
int gridx;
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
gridx = multiprocessor_count;
|
||||
} else {
|
||||
gridx = params.num_blocks_x_int;
|
||||
}
|
||||
dim3 grid_dims = {gridx, 1, 1};
|
||||
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
||||
dim3 block_dims(ctaSize, 1, 1);
|
||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
|
||||
mainloop_params, epilogue_params
|
||||
);
|
||||
if (params.chunk_num > 1) {
|
||||
constexpr int vec_size = 16 / sizeof(DTypeO);
|
||||
constexpr int merge_block_size = 256;
|
||||
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
||||
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.seq_lens_encoder,
|
||||
params.padding_offsets,
|
||||
reinterpret_cast<NV_TYPE*>(params.O),
|
||||
params.max_seq_len,
|
||||
params.chunk_num,
|
||||
params.q_num_head,
|
||||
params.chunk_size,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num
|
||||
);
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
|
||||
constexpr bool CAUSAL = true;
|
||||
if constexpr (HEAD_DIM_QK == 576) {
|
||||
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
|
||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
|
||||
HEAD_DIM_QK,
|
||||
HEAD_DIM_VO,
|
||||
GROUP_SIZE,
|
||||
/*BLOCK_SHAPE_Q_=*/64,
|
||||
/*BLOCK_SHAPE_KV_=*/64,
|
||||
/*NUM_STAGES_=*/2,
|
||||
typename Params::DTypeQ,
|
||||
typename Params::DTypeKV,
|
||||
typename Params::DTypeO,
|
||||
typename Params::IdType,
|
||||
NV_TYPE>,
|
||||
CAUSAL,
|
||||
Params,
|
||||
USE_REG_EALLOC,
|
||||
USE_FIXED_BLOCK>(params, stream);)
|
||||
} else {
|
||||
return cudaErrorNotSupported;
|
||||
}
|
||||
return cudaSuccess;
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
47
custom_ops/gpu_ops/mla_attn/named_barrier.cuh
Normal file
47
custom_ops/gpu_ops/mla_attn/named_barrier.cuh
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
|
||||
#define ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "cutlass/arch/barrier.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
enum class NamedBarriers {
|
||||
kQueryEmpty = 0,
|
||||
kValueEmpty = 1,
|
||||
kWarpSchedulerWG1 = 2,
|
||||
kWarpSchedulerWG2 = 3,
|
||||
kWarpSchedulerWG3 = 4,
|
||||
kPrefetchIndices = 5,
|
||||
kOdone = 6,
|
||||
kWG1WG2Sync = 7,
|
||||
kWG0WG1WG2Sync = 8,
|
||||
kWG1WG2LastSync = 9,
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
|
351
custom_ops/gpu_ops/mla_attn/utils.cuh
Normal file
351
custom_ops/gpu_ops/mla_attn/utils.cuh
Normal file
@@ -0,0 +1,351 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_UTILS_CUH_
|
||||
#define ATTENTION_HOPPER_UTILS_CUH_
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include <assert.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename TensorT>
|
||||
CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) {
|
||||
Tensor tensor_flatten = cute::flatten(tensor);
|
||||
return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride,
|
||||
int64_t h_stride) {
|
||||
return make_layout(make_shape(nnz, head_dim, num_heads),
|
||||
make_stride(n_stride, cute::_1{}, h_stride));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) {
|
||||
return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads)));
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
|
||||
int head_idx, int offset, int seq_len) {
|
||||
auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)),
|
||||
make_coord(offset, _0{}));
|
||||
auto g_sequence =
|
||||
make_tensor(g_offset.data(),
|
||||
make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride()));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
|
||||
int head_idx, int offset, int seq_len) {
|
||||
auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset));
|
||||
|
||||
auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len),
|
||||
cute::make_shape(shape<0>(m_tensor))));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V,
|
||||
// MMA_N))
|
||||
template <typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = acc_layout;
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
|
||||
make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16,
|
||||
// MMA_N))
|
||||
template <typename MMA_traits, typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
|
||||
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
|
||||
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout),
|
||||
make_layout(get<2, 1>(l), get<2>(acc_layout)));
|
||||
};
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template <bool init = false, int wg_wait = 0, typename TensorA, typename TensorB, typename TensorC,
|
||||
typename TiledMma>
|
||||
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB,
|
||||
TensorC& tCrC) {
|
||||
constexpr bool Is_RS =
|
||||
!cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
||||
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
|
||||
if constexpr (Is_RS) {
|
||||
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
|
||||
}
|
||||
warpgroup_fence_operand(tCrC);
|
||||
warpgroup_arrive();
|
||||
if constexpr (init) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
} else {
|
||||
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
if constexpr (wg_wait >= 0) {
|
||||
warpgroup_wait<wg_wait>();
|
||||
}
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (Is_RS) {
|
||||
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
|
||||
}
|
||||
}
|
||||
|
||||
#define HOSTDEVICE __host__ __device__
|
||||
|
||||
template <typename T, int Size>
|
||||
struct alignas(sizeof(T) * Size) AlignedVector {
|
||||
T val[Size];
|
||||
|
||||
HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
|
||||
HOSTDEVICE inline T& operator[](int i) { return val[i]; }
|
||||
};
|
||||
|
||||
template <typename T, int Size>
|
||||
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
|
||||
const AlignedVector<T, Size>* addr_vec =
|
||||
reinterpret_cast<const AlignedVector<T, Size>*>(addr);
|
||||
*vec = *addr_vec;
|
||||
}
|
||||
|
||||
template <typename T, int Size>
|
||||
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
|
||||
AlignedVector<T, Size>* addr_vec =
|
||||
reinterpret_cast<AlignedVector<T, Size>*>(addr);
|
||||
*addr_vec = vec;
|
||||
}
|
||||
|
||||
template <size_t vec_size, typename T>
|
||||
struct prefill_softmax_state_t {
|
||||
AlignedVector<T, vec_size> o;
|
||||
float m;
|
||||
float d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
|
||||
const float other_m,
|
||||
const float other_d) {
|
||||
float m_prev = m, d_prev = d;
|
||||
m = max(m_prev, other_m);
|
||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
|
||||
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
|
||||
d = d_prev * scale1 + other_d * scale2;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize() {
|
||||
const T d_t = static_cast<T>(d);
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d_t;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
|
||||
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const int * __restrict__ seq_lens_this_time,
|
||||
const int * __restrict__ seq_lens_decoder,
|
||||
const int * __restrict__ seq_lens_encoder,
|
||||
const int * __restrict__ padding_offsets,
|
||||
T * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const int max_seq_len,
|
||||
const int num_chunks,
|
||||
const int num_heads,
|
||||
const int chunk_size,
|
||||
const int head_dim,
|
||||
const int token_num,
|
||||
const int bsz,
|
||||
const int max_draft_token_num) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
const int seq_len_q = seq_lens_this_time[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
const uint32_t local_seq_id = ori_token_id % max_seq_len;
|
||||
int seq_len_kv = seq_lens_decoder[bid];
|
||||
if (seq_len_kv == 0) continue;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
// not need merge
|
||||
continue;
|
||||
}
|
||||
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
float m;
|
||||
float d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
m = -3.0e+30f;
|
||||
}
|
||||
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset;
|
||||
offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid;
|
||||
float m_prev = m;
|
||||
float d_prev = d;
|
||||
const float m_now = multi_m[offset];
|
||||
const float d_now = multi_d[offset];
|
||||
m = max(m_prev, m_now);
|
||||
offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
||||
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
if (ty == 0) {
|
||||
// merge bdy
|
||||
prefill_softmax_state_t<vec_size, T> st;
|
||||
st.init();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < bdy; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
Store<T, vec_size>(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_UTILS_CUH_
|
@@ -1255,8 +1255,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
|
469
custom_ops/gpu_ops/multi_head_latent_attention.cu
Normal file
469
custom_ops/gpu_ops/multi_head_latent_attention.cu
Normal file
@@ -0,0 +1,469 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "append_attn/multi_head_latent_attention_kernel.h"
|
||||
#include "mla_attn/batch_mla_with_paged_kv_cache.h"
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& query_bias,
|
||||
const paddle::optional<paddle::Tensor>& query_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
|
||||
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
|
||||
int max_len_kv_data = max_len_kv.data<int>()[0];
|
||||
|
||||
const bool mla_use_tensorcore = get_mla_use_tensorcore();
|
||||
auto sm_version = GetSMVersion();
|
||||
if ((speculate_decoder || mla_use_tensorcore) && sm_version < 90) {
|
||||
PD_THROW("Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm < 90.");
|
||||
}
|
||||
|
||||
auto main_stream = query.stream();
|
||||
|
||||
paddle::Tensor fmha_out = paddle::full(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
|
||||
0,
|
||||
D,
|
||||
query.place());
|
||||
|
||||
if (max_dec_len_this_time_data > 0) {
|
||||
if (mla_use_tensorcore) {
|
||||
BatchMLAWithPagedKVCacheKernel<data_t>(meta_data,
|
||||
query,
|
||||
key_cache,
|
||||
attn_mask,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
cache_quant_type_str,
|
||||
decoder_num_blocks_data,
|
||||
max_input_length,
|
||||
max_len_kv_data,
|
||||
softmax_scale,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
main_stream,
|
||||
&fmha_out);
|
||||
} else {
|
||||
DecodeMLAAttentionKernel<data_t>(
|
||||
meta_data,
|
||||
query, // [token_num, num_heads, head_dim]
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_mask,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
seq_lens_this_time, // q_seq_len is 1
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_input_length,
|
||||
max_len_kv_data,
|
||||
softmax_scale,
|
||||
out_linear_in_scale,
|
||||
causal,
|
||||
main_stream,
|
||||
&fmha_out);
|
||||
}
|
||||
}
|
||||
return {fmha_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& query_bias,
|
||||
const paddle::optional<paddle::Tensor>& query_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
AppendAttnMetaData meta_data;
|
||||
|
||||
const auto& query_dims = query.dims();
|
||||
const auto& key_cache_dims = key_cache.dims();
|
||||
const int q_hidden_size = query_dims[query_dims.size() - 1];
|
||||
meta_data.token_nums = query_dims[0];
|
||||
meta_data.kv_num_heads = key_cache_dims[1];
|
||||
meta_data.head_dims = key_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
|
||||
switch (query.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return MultiHeadLatentAttentionKernel<paddle::DataType::BFLOAT16>(
|
||||
meta_data,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
decoder_num_blocks_cpu,
|
||||
max_enc_len_this_time,
|
||||
max_dec_len_this_time,
|
||||
max_len_kv,
|
||||
attn_mask,
|
||||
query_bias,
|
||||
query_out_scales,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
cache_quant_type_str,
|
||||
max_input_length,
|
||||
softmax_scale,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return MultiHeadLatentAttentionKernel<paddle::DataType::FLOAT16>(
|
||||
meta_data,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
decoder_num_blocks_cpu,
|
||||
max_enc_len_this_time,
|
||||
max_dec_len_this_time,
|
||||
max_len_kv,
|
||||
attn_mask,
|
||||
query_bias,
|
||||
query_out_scales,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
cache_quant_type_str,
|
||||
max_input_length,
|
||||
softmax_scale,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and bfloat16 are supported. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
|
||||
const std::vector<int64_t>& query_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
const std::vector<int64_t>& value_cache_shape,
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& padding_offsets_shape,
|
||||
const std::vector<int64_t>& cum_offsets_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& encoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& kv_batch_ids_shape,
|
||||
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& kv_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
|
||||
const std::vector<int64_t>& max_enc_len_this_time_shape,
|
||||
const std::vector<int64_t>& max_dec_len_this_time_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& query_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& query_out_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
const int token_num = query_shape[0];
|
||||
const int kv_num_heads = key_cache_shape[1];
|
||||
const int head_dim_qk = key_cache_shape[3];
|
||||
const int head_dim_v = nope_size;
|
||||
const int q_hidden_size = query_shape[query_shape.size() - 1];
|
||||
const int num_heads = q_hidden_size / head_dim_qk;
|
||||
return {{token_num, num_heads * head_dim_v}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
|
||||
const paddle::DataType& query_dtype,
|
||||
const paddle::DataType& key_cache_dtype,
|
||||
const paddle::DataType& value_cache_dtype,
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& padding_offsets_dtype,
|
||||
const paddle::DataType& cum_offsets_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& encoder_num_blocks_dtype,
|
||||
const paddle::DataType& kv_batch_ids_dtype,
|
||||
const paddle::DataType& kv_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& kv_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_batch_ids_dtype,
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_cpu_dtype,
|
||||
const paddle::DataType& max_enc_len_this_time_dtype,
|
||||
const paddle::DataType& max_dec_len_this_time_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
const paddle::optional<paddle::DataType>& query_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& query_out_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
if (compute_dtype == "bf16") {
|
||||
return {paddle::DataType::BFLOAT16};
|
||||
} else if (compute_dtype == "fp16") {
|
||||
return {paddle::DataType::FLOAT16};
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_OP(multi_head_latent_attention)
|
||||
.Inputs({"query",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"cu_seqlens_q",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"decoder_num_blocks_cpu",
|
||||
"max_enc_len_this_time",
|
||||
"max_dec_len_this_time",
|
||||
"max_len_kv",
|
||||
paddle::Optional("attn_mask"),
|
||||
paddle::Optional("query_bias"),
|
||||
paddle::Optional("query_out_scales"),
|
||||
paddle::Optional("cache_k_quant_scales"),
|
||||
paddle::Optional("cache_v_quant_scales"),
|
||||
paddle::Optional("cache_k_dequant_scales"),
|
||||
paddle::Optional("cache_v_dequant_scales"),
|
||||
paddle::Optional("cache_k_zp"),
|
||||
paddle::Optional("cache_v_zp"),
|
||||
paddle::Optional("out_linear_shifts"),
|
||||
paddle::Optional("out_linear_smooths")})
|
||||
.Outputs({"fmha_out"})
|
||||
.Attrs({"compute_type: std::string",
|
||||
"cache_quant_type: std::string",
|
||||
"nope_size: int",
|
||||
"max_input_length: int",
|
||||
"softmax_scale: float",
|
||||
"quant_max_bound: float",
|
||||
"quant_min_bound: float",
|
||||
"out_linear_in_scale: float",
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool"})
|
||||
.SetKernelFn(PD_KERNEL(MultiHeadLatentAttention))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MultiHeadLatentAttentionInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MultiHeadLatentAttentionInferDtype));
|
73
custom_ops/gpu_ops/noaux_tc.cu
Normal file
73
custom_ops/gpu_ops/noaux_tc.cu
Normal file
@@ -0,0 +1,73 @@
|
||||
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
|
||||
#include "helper.h"
|
||||
#include "noauxtc_kernel.h"
|
||||
|
||||
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
paddle::Tensor& scores_with_bias,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
float routed_scaling_factor) {
|
||||
auto input_shape = scores_with_bias.shape();
|
||||
int64_t num_tokens = input_shape[0];
|
||||
int64_t num_experts = input_shape[1];
|
||||
auto input_type = scores_with_bias.dtype();
|
||||
auto place = scores_with_bias.place();
|
||||
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
|
||||
auto stream = scores_with_bias.stream();
|
||||
|
||||
invokeNoAuxTc<float>(reinterpret_cast<float*>(scores.data<float>()),
|
||||
reinterpret_cast<float*>(group_scores.data<float>()),
|
||||
reinterpret_cast<float*>(scores_with_bias.data<float>()),
|
||||
num_tokens,
|
||||
num_experts,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
|
||||
return {scores};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> NoauxTcInferDtype(
|
||||
const paddle::DataType& scores_dtype,
|
||||
const paddle::DataType& scores_with_bias_dtype) {
|
||||
return {scores_dtype};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> NoauxTcInferShape(
|
||||
const std::vector<int64_t>& scores_shape,
|
||||
const std::vector<int64_t>& gating_output_shape) {
|
||||
return {scores_shape};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(noaux_tc)
|
||||
.Inputs({"scores", "scores_with_bias"})
|
||||
.Outputs({"output_tensor"})
|
||||
.Attrs({"n_group: int",
|
||||
"topk_group: int",
|
||||
"topk:int",
|
||||
"routed_scaling_factor: float"})
|
||||
.SetKernelFn(PD_KERNEL(NoauxTc))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcInferDtype));
|
551
custom_ops/gpu_ops/noauxtc_kernel.h
Normal file
551
custom_ops/gpu_ops/noauxtc_kernel.h
Normal file
@@ -0,0 +1,551 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// This code is partially inspired by and references the implementation found
|
||||
// in NVIDIA TRTLLM.
|
||||
#pragma once
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t WARP_SIZE = 32;
|
||||
constexpr int32_t BLOCK_SIZE = 512;
|
||||
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
||||
|
||||
namespace warp_topk {
|
||||
|
||||
template <int size, typename T>
|
||||
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
return ((len - 1) / size + 1) * size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr __host__ __device__ bool isPowerOf2(T v) {
|
||||
return (v && !(v & (v - 1)));
|
||||
}
|
||||
|
||||
template <bool greater, typename T>
|
||||
__device__ bool is_better_than(T val, T baseline) {
|
||||
return (val > baseline && greater) || (val < baseline && !greater);
|
||||
}
|
||||
|
||||
template <typename T, typename idxT>
|
||||
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
|
||||
int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
|
||||
return max(cache_topk,
|
||||
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
|
||||
}
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
struct BitonicMerge {
|
||||
// input should be a bitonic sequence, and sort it to be a monotonic sequence
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
static_assert(isPowerOf2(size));
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
constexpr int stride = arr_len / 2;
|
||||
for (int i = 0; i < stride; ++i) {
|
||||
int const other_i = i + stride;
|
||||
T& val = val_arr[i];
|
||||
T& other_val = val_arr[other_i];
|
||||
if ((val > other_val && ascending) || (val < other_val && !ascending)) {
|
||||
T tmp = val;
|
||||
val = other_val;
|
||||
other_val = tmp;
|
||||
|
||||
idxT tmp2 = idx_arr[i];
|
||||
idx_arr[i] = idx_arr[other_i];
|
||||
idx_arr[other_i] = tmp2;
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
struct BitonicSort {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
static_assert(isPowerOf2(size));
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicSort<32, ascending, T, idxT> {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// ascending doesn't matter before merging since all we need is a bitonic
|
||||
// sequence
|
||||
for (int stage = 0; stage < 4; ++stage) {
|
||||
for (int stride = (1 << stage); stride > 0; stride /= 2) {
|
||||
bool reverse = (lane >> stage) & 2;
|
||||
bool is_second = lane & stride;
|
||||
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
|
||||
if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
|
||||
*val_arr = other;
|
||||
*idx_arr = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicMerge<32, ascending, T, idxT> {
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
|
||||
bool is_second = lane & stride;
|
||||
T& val = *val_arr;
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
|
||||
idxT& idx = *idx_arr;
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
|
||||
if (val != other && ((val > other) == (ascending != is_second))) {
|
||||
val = other;
|
||||
idx = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT>
|
||||
class WarpSort {
|
||||
public:
|
||||
__device__ WarpSort(idxT k, T dummy)
|
||||
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
|
||||
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
|
||||
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
val_arr_[i] = dummy_;
|
||||
}
|
||||
}
|
||||
|
||||
// load and merge k sorted values
|
||||
__device__ void load_sorted(T const* __restrict__ in,
|
||||
idxT const* __restrict__ in_idx,
|
||||
idxT start) {
|
||||
idxT idx = start + WARP_SIZE - 1 - lane_;
|
||||
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
|
||||
if (idx < start + k_) {
|
||||
T t = in[idx];
|
||||
if (is_better_than<greater>(t, val_arr_[i])) {
|
||||
val_arr_[i] = t;
|
||||
idx_arr_[i] = in_idx[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
|
||||
}
|
||||
|
||||
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
idxT out_i = i * WARP_SIZE + lane_;
|
||||
if (out_i < k_) {
|
||||
out[out_i] = val_arr_[i];
|
||||
out_idx[out_i] = idx_arr_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void dumpIdx(idxT* __restrict__ out_idx) const {
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
idxT out_i = i * WARP_SIZE + lane_;
|
||||
if (out_i < k_) {
|
||||
out_idx[out_i] = idx_arr_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
|
||||
|
||||
T val_arr_[max_arr_len_];
|
||||
idxT idx_arr_[max_arr_len_];
|
||||
|
||||
int const lane_;
|
||||
idxT const k_;
|
||||
T const dummy_;
|
||||
|
||||
}; // end class WarpSort
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
|
||||
public:
|
||||
__device__ WarpSelect(idxT k, T dummy)
|
||||
: WarpSort<capacity, greater, T, idxT>(k, dummy),
|
||||
k_th_(dummy),
|
||||
k_th_lane_((k - 1) % WARP_SIZE) {
|
||||
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
|
||||
|
||||
int const num_of_warp = blockDim.x / WARP_SIZE;
|
||||
int const warp_id = threadIdx.x / WARP_SIZE;
|
||||
val_smem_ = reinterpret_cast<T*>(smem_buf);
|
||||
val_smem_ += warp_id * WARP_SIZE;
|
||||
idx_smem_ = reinterpret_cast<idxT*>(
|
||||
smem_buf +
|
||||
round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
|
||||
idx_smem_ += warp_id * WARP_SIZE;
|
||||
}
|
||||
|
||||
__device__ void add(T const* in, idxT start, idxT end) {
|
||||
idxT const end_for_fullwarp =
|
||||
round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
|
||||
for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
|
||||
T val = (i < end) ? in[i] : dummy_;
|
||||
add(val, i);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void add(T val, idxT idx) {
|
||||
bool do_add = is_better_than<greater>(val, k_th_);
|
||||
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
|
||||
if (mask == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
|
||||
if (do_add && pos < WARP_SIZE) {
|
||||
val_smem_[pos] = val;
|
||||
idx_smem_[pos] = idx;
|
||||
do_add = false;
|
||||
}
|
||||
smem_buf_len_ += __popc(mask);
|
||||
if (smem_buf_len_ >= WARP_SIZE) {
|
||||
__syncwarp();
|
||||
merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
|
||||
smem_buf_len_ -= WARP_SIZE;
|
||||
}
|
||||
if (do_add) {
|
||||
pos -= WARP_SIZE;
|
||||
val_smem_[pos] = val;
|
||||
idx_smem_[pos] = idx;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
__device__ void done() {
|
||||
if (smem_buf_len_) {
|
||||
T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
|
||||
idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
|
||||
merge_buf_(val, idx);
|
||||
}
|
||||
|
||||
// after done(), smem is used for merging results among warps
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
private:
|
||||
__device__ void set_k_th_() {
|
||||
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
}
|
||||
|
||||
__device__ void merge_buf_(T val, idxT idx) {
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
|
||||
|
||||
T& old = val_arr_[max_arr_len_ - 1];
|
||||
if (is_better_than<greater>(val, old)) {
|
||||
old = val;
|
||||
idx_arr_[max_arr_len_ - 1] = idx;
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
|
||||
|
||||
set_k_th_();
|
||||
}
|
||||
|
||||
using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT>::dummy_;
|
||||
|
||||
T* val_smem_;
|
||||
idxT* idx_smem_;
|
||||
int smem_buf_len_ = 0;
|
||||
|
||||
T k_th_;
|
||||
int const k_th_lane_;
|
||||
}; // end class WarpSelect
|
||||
} // namespace warp_topk
|
||||
|
||||
template <typename T>
|
||||
__device__ void topk_with_k2(T* output,
|
||||
T const* input,
|
||||
cg::thread_block_tile<32> const& tile,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
// Get the top2 per thread
|
||||
T largest = cuda::std::numeric_limits<T>::min();
|
||||
T second_largest = cuda::std::numeric_limits<T>::min();
|
||||
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
T value = input[i];
|
||||
if (value > largest) {
|
||||
second_largest = largest;
|
||||
largest = value;
|
||||
} else if (value > second_largest) {
|
||||
second_largest = value;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
largest = input[i];
|
||||
}
|
||||
}
|
||||
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
// Get the top2 warpwise
|
||||
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
|
||||
T max2 = max1;
|
||||
bool equal_to_max1 = (max1 == largest);
|
||||
int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));
|
||||
|
||||
if (count_max1 == 1) {
|
||||
largest = (largest == max1) ? second_largest : largest;
|
||||
max2 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
}
|
||||
|
||||
if (lane_id == 0) {
|
||||
*output = max1 + max2;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void topk_with_k2_kernel(T* output,
|
||||
T* input,
|
||||
int64_t const num_tokens,
|
||||
int64_t const num_cases,
|
||||
int64_t const n_group,
|
||||
int64_t const num_experts_per_group) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
|
||||
if (case_id < num_cases) {
|
||||
input += case_id * num_experts_per_group;
|
||||
output += case_id;
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void group_idx_and_topk_idx_kernel(
|
||||
T* scores,
|
||||
T const* group_scores,
|
||||
T* scores_with_bias,
|
||||
int64_t const num_tokens,
|
||||
int64_t const n_group,
|
||||
int64_t const topk_group,
|
||||
int64_t const topk,
|
||||
int64_t const num_experts,
|
||||
int64_t const num_experts_per_group,
|
||||
double routed_scaling_factor) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
int32_t case_id =
|
||||
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
|
||||
scores_with_bias += case_id * num_experts;
|
||||
scores += case_id * num_experts;
|
||||
group_scores += case_id * n_group;
|
||||
int32_t align_num_experts_per_group =
|
||||
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
|
||||
// store the target topk idx
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
|
||||
T* s_topk_value =
|
||||
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
|
||||
warp_id * topk;
|
||||
|
||||
T value = cuda::std::numeric_limits<T>::min();
|
||||
T topk_group_value = cuda::std::numeric_limits<T>::min();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
if ((n_group > topk_group) && (case_id < num_tokens)) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group) {
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
int count_equal_to_top_value = WARP_SIZE - n_group;
|
||||
int pre_count_equal_to_top_value = 0;
|
||||
// Use loop to find the largset top_group
|
||||
while (count_equal_to_top_value < target_num_min) {
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
|
||||
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
if (case_id < num_tokens) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
((group_scores[i_group] == topk_group_value) &&
|
||||
(count_equalto_topkth_group < num_equalto_topkth_group))) {
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates = i < num_experts_per_group
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda::std::numeric_limits<T>::min();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
count_equalto_topkth_group++;
|
||||
}
|
||||
}
|
||||
}
|
||||
queue.done();
|
||||
__syncwarp();
|
||||
// Get the topk_idx
|
||||
queue.dumpIdx(s_topk_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Load the valid score value
|
||||
// Calculate the summation
|
||||
float topk_sum = 1e-20;
|
||||
if (case_id < num_tokens) {
|
||||
for (int i = lane_id;
|
||||
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
||||
i += WARP_SIZE) {
|
||||
T value = i < topk ? scores[s_topk_idx[i]]
|
||||
: 0.0f; // Load the valid value of expert
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, value, cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (case_id < num_tokens) {
|
||||
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
|
||||
scores[i] = 0;
|
||||
}
|
||||
}
|
||||
__threadfence();
|
||||
__syncthreads();
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
|
||||
scores[s_topk_idx[i]] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeNoAuxTc(T* scores,
|
||||
T* group_scores,
|
||||
T* scores_with_bias,
|
||||
int64_t const num_tokens,
|
||||
int64_t const num_experts,
|
||||
int64_t const n_group,
|
||||
int64_t const topk_group,
|
||||
int64_t const topk,
|
||||
double const routed_scaling_factor,
|
||||
cudaStream_t const stream) {
|
||||
int64_t num_cases = num_tokens * n_group;
|
||||
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
|
||||
group_scores,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
num_cases,
|
||||
n_group,
|
||||
num_experts / n_group);
|
||||
|
||||
int64_t topk_with_k_group_num_blocks =
|
||||
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
size_t dynamic_smem_in_bytes =
|
||||
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
||||
topk);
|
||||
|
||||
group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
|
||||
BLOCK_SIZE,
|
||||
dynamic_smem_in_bytes,
|
||||
stream>>>(scores,
|
||||
group_scores,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
num_experts,
|
||||
num_experts / n_group,
|
||||
routed_scaling_factor);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T) \
|
||||
template void invokeNoAuxTc<T>(T * scores, \
|
||||
T * group_scores, \
|
||||
T * scores_with_bias, \
|
||||
int64_t const num_tokens, \
|
||||
int64_t const num_experts, \
|
||||
int64_t const n_group, \
|
||||
int64_t const topk_group, \
|
||||
int64_t const topk, \
|
||||
double const routed_scaling_factor, \
|
||||
cudaStream_t const stream);
|
||||
|
||||
INSTANTIATE_NOAUX_TC(float);
|
@@ -50,11 +50,13 @@ __global__ void quant_per_token_per_block(const T *input,
|
||||
max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
|
||||
}
|
||||
// get max value per warp
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread);
|
||||
// broadcast max_value
|
||||
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
|
||||
max_value_thread = max(max_value_thread, epsilon);
|
||||
float scale_to_store = max_value_thread / MAX_VALUE;
|
||||
// quant
|
||||
|
@@ -267,6 +267,9 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/text_image_index_out.cu",
|
||||
"gpu_ops/text_image_gather_scatter.cu",
|
||||
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
|
||||
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
|
||||
"gpu_ops/fused_rotary_position_encoding.cu",
|
||||
"gpu_ops/noaux_tc.cu",
|
||||
]
|
||||
|
||||
# pd_disaggregation
|
||||
@@ -376,6 +379,8 @@ elif paddle.is_compiled_with_cuda():
|
||||
# append_attention
|
||||
sources += ["gpu_ops/append_attention.cu"]
|
||||
sources += find_end_files("gpu_ops/append_attn", ".cu")
|
||||
# mla
|
||||
sources += ["gpu_ops/multi_head_latent_attention.cu"]
|
||||
# gemm_dequant
|
||||
sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"]
|
||||
# speculate_decoding
|
||||
@@ -441,6 +446,10 @@ elif paddle.is_compiled_with_cuda():
|
||||
|
||||
sources += find_end_files(fp8_auto_gen_directory, ".cu")
|
||||
|
||||
if cc >= 90 and nvcc_version >= 12.0:
|
||||
# Hopper optmized mla
|
||||
sources += find_end_files("gpu_ops/mla_attn", ".cu")
|
||||
|
||||
setup(
|
||||
name="fastdeploy_ops",
|
||||
ext_modules=CUDAExtension(
|
||||
|
@@ -9,11 +9,6 @@ COPY . /workspace/FastDeploy
|
||||
RUN echo "ulimit -u unlimited" >> /root/.bashrc
|
||||
RUN echo "ulimit -n 65536" >> /root/.bashrc
|
||||
|
||||
# setting proxy
|
||||
ARG http_proxy=agent.baidu.com:8891
|
||||
ARG https_proxy=agent.baidu.com:8891
|
||||
ARG no_proxy=localhost,bj.bcebos.com,su.bcebos.com,pypi.tuna.tsinghua.edu.cn,paddle-ci.gz.bcebos.com
|
||||
|
||||
# uninstall existing package
|
||||
RUN python -m pip uninstall paddlepaddle-gpu fastdeploy-gpu -y
|
||||
|
||||
|
@@ -7,11 +7,6 @@ deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy main restricted universe
|
||||
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-updates main restricted universe multiverse \n\
|
||||
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-backports main restricted universe multiverse" > /etc/apt/sources.list
|
||||
|
||||
# setting proxy
|
||||
ENV http_proxy=http://agent.baidu.com:8891
|
||||
ENV https_proxy=http://agent.baidu.com:8891
|
||||
ENV no_proxy=localhost,bj.bcebos.com,su.bcebos.com,pypi.tuna.tsinghua.edu.cn,paddle-ci.gz.bcebos.com
|
||||
|
||||
RUN apt-get update && apt-get install -y libibverbs-dev librdmacm-dev cmake pybind11-dev
|
||||
|
||||
# uninstall existing package
|
||||
@@ -40,4 +35,4 @@ COPY . /workspace/FastDeploy
|
||||
RUN cd /workspace/FastDeploy && bash build.sh && python -m pip install --no-cache-dir dist/* && rm -rf /workspace/FastDeploy
|
||||
|
||||
ENV http_proxy=""
|
||||
ENV https_proxy=""
|
||||
ENV https_proxy=""
|
||||
|
@@ -1,19 +1,19 @@
|
||||
# Chain-of-Thought Content
|
||||
# Reasoning Outputs
|
||||
|
||||
The reasoning model returns a `reasoning_content` field in the output, representing the chain-of-thought content—the reasoning steps that lead to the final conclusion.
|
||||
Reasoning models return an additional `reasoning_content` field in their output, which contains the reasoning steps that led to the final conclusion.
|
||||
|
||||
## Currently Supported Chain-of-Thought Models
|
||||
| Model Name | Parser Name | Chain-of-Thought Enabled by Default |
|
||||
|----------------|----------------|-------------------------------------|
|
||||
| ernie-45-vl | ernie-45-vl | ✓ |
|
||||
| ernie-lite-vl | ernie-45-vl | ✓ |
|
||||
## Supported Models
|
||||
| Model Name | Parser Name | Eable_thinking by Default |
|
||||
|----------------|----------------|---------------------------|
|
||||
| baidu/ERNIE-4.5-VL-424B-A47B-Paddle | ernie-45-vl | ✓ |
|
||||
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | ernie-45-vl | ✓ |
|
||||
|
||||
The reasoning model requires a specified parser to interpret the reasoning content. The reasoning mode can be disabled by setting the `enable_thinking=False` parameter.
|
||||
The reasoning model requires a specified parser to extract reasoning content. The reasoning mode can be disabled by setting the `enable_thinking=False` parameter.
|
||||
|
||||
Interfaces that support toggling the reasoning mode:
|
||||
1. `/v1/chat/completions` request in OpenAI services.
|
||||
2. `/v1/chat/completions` request in the OpenAI Python client.
|
||||
3. `llm.chat` request in Offline interfaces.
|
||||
1. `/v1/chat/completions` requests in OpenAI services.
|
||||
2. `/v1/chat/completions` requests in the OpenAI Python client.
|
||||
3. `llm.chat` requests in Offline interfaces.
|
||||
|
||||
For reasoning models, the length of the reasoning content can be controlled via `reasoning_max_tokens`. Add `metadata={"reasoning_max_tokens": 1024}` to the request.
|
||||
|
||||
@@ -21,10 +21,15 @@ For reasoning models, the length of the reasoning content can be controlled via
|
||||
When launching the model service, specify the parser name using the `--reasoning-parser` argument.
|
||||
This parser will process the model's output and extract the `reasoning_content` field.
|
||||
```bash
|
||||
python -m fastdeploy.entrypoints.openai.api_server --model /root/merge_llm_model --enable-mm --tensor-parallel-size=8 --port 8192 --quantization wint4 --reasoning-parser=ernie-45-vl
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model /path/to/your/model \
|
||||
--enable-mm \
|
||||
--tensor-parallel-size 8 \
|
||||
--port 8192 \
|
||||
--quantization wint4 \
|
||||
--reasoning-parser ernie-45-vl
|
||||
```
|
||||
|
||||
Next, send a `chat completion` request to the model:
|
||||
Next, make a request to the model that should return the reasoning content in the response.
|
||||
```bash
|
||||
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
@@ -40,8 +45,8 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
|
||||
```
|
||||
The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself.
|
||||
|
||||
### Streaming Sessions
|
||||
In streaming sessions, the `reasoning_content` field can be retrieved from the `delta` in `chat completion response chunks`.
|
||||
### Streaming chat completions
|
||||
Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in `chat completion response chunks`
|
||||
```python
|
||||
from openai import OpenAI
|
||||
# Set OpenAI's API key and API base to use vLLM's API server.
|
||||
|
@@ -62,7 +62,7 @@ The differences in request parameters between FastDeploy and the OpenAI protocol
|
||||
- `stream_options`: Optional[StreamOptions] = None
|
||||
- `temperature`: Optional[float] = None
|
||||
- `top_p`: Optional[float] = None
|
||||
- `metadata`: Optional[dict] = None (supported only in `v1/chat/completions` for configuring additional parameters, e.g., `meta_data={"enable_thinking": True}`)
|
||||
- `metadata`: Optional[dict] = None (supported only in `v1/chat/completions` for configuring additional parameters, e.g., `metadata={"enable_thinking": True}`)
|
||||
- `min_tokens`: Optional[int] = 1 (minimum number of tokens generated)
|
||||
- `reasoning_max_tokens`: Optional[int] = None (maximum number of tokens for reasoning content, defaults to the same as `max_tokens`)
|
||||
- `enable_thinking`: Optional[bool] = True (whether to enable reasoning for models that support deep thinking)
|
||||
|
@@ -27,7 +27,7 @@ When using FastDeploy to deploy models (including offline inference and service
|
||||
| ```kv_cache_ratio``` | `float` | KVCache blocks are divided between Prefill phase and Decode phase according to kv_cache_ratio ratio, default: 0.75 |
|
||||
| ```enable_prefix_caching``` | `bool` | Whether to enable Prefix Caching, default: False |
|
||||
| ```swap_space``` | `float` | When Prefix Caching is enabled, CPU memory size for KVCache swapping, unit: GB, default: None |
|
||||
| ```enable_chunk_prefill``` | `bool` | Enable Chunked Prefill, default: False |
|
||||
| ```enable_chunked_prefill``` | `bool` | Enable Chunked Prefill, default: False |
|
||||
| ```max_num_partial_prefills``` | `int` | When Chunked Prefill is enabled, maximum concurrent number of partial prefill batches, default: 1 |
|
||||
| ```max_long_partial_prefills``` | `int` | When Chunked Prefill is enabled, maximum number of long requests in concurrent partial prefill batches, default: 1 |
|
||||
| ```long_prefill_token_threshold``` | `int` | When Chunked Prefill is enabled, requests with token count exceeding this value are considered long requests, default: max_model_len*0.04 |
|
||||
@@ -115,5 +115,5 @@ FastDeploy initialization sequence first uses `gpu_memory_utilization` parameter
|
||||
...
|
||||
```
|
||||
- When ```use_cudagraph``` is enabled, currently only supports single-GPU inference, i.e. ```tensor_parallel_size``` set to 1.
|
||||
- When ```use_cudagraph``` is enabled, cannot enable ```enable_prefix_caching``` or ```enable_chunk_prefill```.
|
||||
- When ```use_cudagraph``` is enabled, cannot enable ```enable_prefix_caching``` or ```enable_chunked_prefill```.
|
||||
- When ```use_cudagraph``` is enabled, batches with size ≤ ```max_capture_batch_size``` will be executed by CudaGraph, batches > ```max_capture_batch_size``` will be executed by original dynamic/static graph. To have all batch sizes executed by CudaGraph, ```max_capture_batch_size``` value should match ```max_num_seqs```. ```max_capture_batch_size``` > ```max_num_seqs``` will cause waste by capturing batches that won't be encountered during inference, occupying more time and memory.
|
@@ -57,3 +57,6 @@ On the ERNIE-4.5-300B-A47B model, comparison of WINT2 vs WINT4 performance:
|
||||
| IFEval |500|88.17 | 85.40 |
|
||||
|BBH|6511|94.43|92.02|
|
||||
|DROP|9536|91.17|89.97|
|
||||
|GSM8K|1319|96.21|95.98|
|
||||
|CMath|600|96.50|96.00|
|
||||
|CMMLU|11477|89.92|86.22|
|
@@ -5,8 +5,8 @@
|
||||
##目前支持思考链的模型
|
||||
| 模型名称 | 解析器名称 | 默认开启思考链 |
|
||||
|---------------|-------------|---------|
|
||||
| ernie-45-vl | ernie-45-vl | ✓ |
|
||||
| ernie-lite-vl | ernie-45-vl | ✓ |
|
||||
| baidu/ERNIE-4.5-VL-424B-A47B-Paddle | ernie-45-vl | ✓ |
|
||||
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | ernie-45-vl | ✓ |
|
||||
|
||||
思考模型需要指定解析器,以便于对思考内容进行解析. 通过`enable_thinking=False` 参数可以关闭模型思考模式.
|
||||
|
||||
|
@@ -61,7 +61,7 @@ FastDeploy 与 OpenAI 协议的请求参数差异如下,其余请求参数会
|
||||
- `stream_options`: Optional[StreamOptions] = None
|
||||
- `temperature`: Optional[float] = None
|
||||
- `top_p`: Optional[float] = None
|
||||
- `metadata`: Optional[dict] = None (仅在v1/chat/compeltions中支持,用于配置额外参数, 如meta_data={"enable_thinking": True})
|
||||
- `metadata`: Optional[dict] = None (仅在v1/chat/compeltions中支持,用于配置额外参数, 如metadata={"enable_thinking": True})
|
||||
- `min_tokens`: Optional[int] = 1 最小生成的Token个数
|
||||
- `reasoning_max_tokens`: Optional[int] = None 思考内容最大Token数,默认与max_tokens一致
|
||||
- `enable_thinking`: Optional[bool] = True 支持深度思考的模型是否打开思考
|
||||
|
@@ -26,7 +26,7 @@
|
||||
| ```kv_cache_ratio``` | `float` | KVCache块按kv_cache_ratio比例分给Prefill阶段和Decode阶段, 默认0.75 |
|
||||
| ```enable_prefix_caching``` | `bool` | 是否开启Prefix Caching,默认False |
|
||||
| ```swap_space``` | `float` | 开启Prefix Caching时,用于swap KVCache的CPU内存大小,单位GB,默认None |
|
||||
| ```enable_chunk_prefill``` | `bool` | 开启Chunked Prefill,默认False |
|
||||
| ```enable_chunked_prefill``` | `bool` | 开启Chunked Prefill,默认False |
|
||||
| ```max_num_partial_prefills``` | `int` | 开启Chunked Prefill时,Prefill阶段的最大并发数,默认1 |
|
||||
| ```max_long_partial_prefills``` | `int` | 开启Chunked Prefill时,Prefill阶段并发中包启的最多长请求数,默认1 |
|
||||
| ```long_prefill_token_threshold``` | `int` | 开启Chunked Prefill时,请求Token数超过此值的请求被视为长请求,默认为max_model_len*0.04 |
|
||||
@@ -113,5 +113,5 @@ FastDeploy 的初始化顺序为先使用 `gpu_memory_utilization` 参数计算
|
||||
...
|
||||
```
|
||||
- 当开启 ```use_cudagraph``` 时,暂时只支持单卡推理,即 ```tensor_parallel_size``` 设为1。
|
||||
- 当开启 ```use_cudagraph``` 时,暂不支持开启 ```enable_prefix_caching``` 或 ```enable_chunk_prefill``` 。
|
||||
- 当开启 ```use_cudagraph``` 时,暂不支持开启 ```enable_prefix_caching``` 或 ```enable_chunked_prefill``` 。
|
||||
- 当开启 ```use_cudagraph``` 后,size小于等于 ```max_capture_batch_size``` 的batch会由CudaGraph来执行前向计算,大于 ```max_capture_batch_size``` 的batch会由原本的动态图/静态图执行前向计算。如果希望所有batch size均由CudaGraph来执行,```max_capture_batch_size``` 的值建议与 ```max_num_seqs``` 一致。```max_capture_batch_size``` 大于 ```max_num_seqs``` 会导致浪费,会多捕获一些推理时不会遇到的batch,占用更多时间与显存。
|
||||
|
@@ -15,11 +15,14 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# suppress warning log from paddlepaddle
|
||||
os.environ["GLOG_minloglevel"] = "2"
|
||||
# suppress log from aistudio
|
||||
os.environ["AISTUDIO_LOG"] = "critical"
|
||||
from fastdeploy.utils import version
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.entrypoints.llm import LLM
|
||||
|
||||
@@ -30,3 +33,48 @@ try:
|
||||
use_triton_in_paddle.make_triton_compatible_with_paddle()
|
||||
except ImportError:
|
||||
pass
|
||||
# TODO(tangbinhan): remove this code
|
||||
|
||||
|
||||
def _patch_fastsafetensors():
|
||||
try:
|
||||
file_path = subprocess.check_output([
|
||||
sys.executable, "-c", "import fastsafetensors, os; \
|
||||
print(os.path.join(os.path.dirname(fastsafetensors.__file__), \
|
||||
'frameworks', '_paddle.py'))"
|
||||
]).decode().strip()
|
||||
|
||||
with open(file_path, 'r') as f:
|
||||
content = f.read()
|
||||
if "DType.U16: DType.BF16," in content and "DType.U8: paddle.uint8," in content:
|
||||
return
|
||||
|
||||
modified = False
|
||||
if "DType.U16: DType.BF16," not in content:
|
||||
lines = content.splitlines()
|
||||
new_lines = []
|
||||
inside_block = False
|
||||
for line in lines:
|
||||
new_lines.append(line)
|
||||
if 'need_workaround_dtypes: Dict[DType, DType] = {' in line:
|
||||
inside_block = True
|
||||
elif inside_block and '}' in line:
|
||||
new_lines.insert(-1, ' DType.U16: DType.BF16,')
|
||||
inside_block = False
|
||||
modified = True
|
||||
content = "\n".join(new_lines)
|
||||
|
||||
if "DType.I8: paddle.uint8," in content:
|
||||
content = content.replace("DType.I8: paddle.uint8,",
|
||||
"DType.U8: paddle.uint8,")
|
||||
modified = True
|
||||
|
||||
if modified:
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(content + "\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to patch fastsafetensors: {e}")
|
||||
|
||||
|
||||
_patch_fastsafetensors()
|
||||
|
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
@@ -51,7 +51,6 @@ class ModelConfig(PretrainedConfig):
|
||||
top_p = 0.0
|
||||
temperature = 1.0
|
||||
rope_theta = 10000.0
|
||||
rope_scaling = None
|
||||
penalty_score = 1.0
|
||||
frequency_score = 0.0
|
||||
presence_score = 0.0
|
||||
@@ -70,7 +69,6 @@ class ModelConfig(PretrainedConfig):
|
||||
max_seq_len: int = 512,
|
||||
initializer_range: float = 0.02,
|
||||
use_rope=True,
|
||||
use_fast_ffn: bool = False,
|
||||
rope_theta: int = 10000,
|
||||
rope_3d: bool = False,
|
||||
ori_vocab_size: int | None = None,
|
||||
@@ -105,7 +103,6 @@ class ModelConfig(PretrainedConfig):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.use_rope = use_rope
|
||||
self.use_fast_ffn = use_fast_ffn
|
||||
self.rope_theta = rope_theta
|
||||
self.ori_vocab_size = ori_vocab_size or vocab_size
|
||||
self.max_seq_len = max_seq_len
|
||||
@@ -142,6 +139,7 @@ class MoEConfig:
|
||||
moe_num_shared_experts = (0, )
|
||||
moe_layer_start_index = 0
|
||||
moe_layer_end_index = None
|
||||
moe_use_aux_free: bool = False
|
||||
num_max_dispatch_tokens_per_rank = 256
|
||||
im_patch_id = (
|
||||
100295 # multimodality, TODO(liuyuanle): read from config.json
|
||||
@@ -163,7 +161,6 @@ class ParallelConfig:
|
||||
# The embedding weight distributed on your gpu cards is divided by row or column.
|
||||
# Defaults to False means divide by row. When vocab_size can not be divided by world_size
|
||||
# but hidden_size can, we can consider split embedding weight by column.
|
||||
column_cut = False # (bool, optional)
|
||||
"""
|
||||
From old wersion worker args
|
||||
TODO(gongshaotian): Reclassify
|
||||
@@ -194,18 +191,13 @@ class ParallelConfig:
|
||||
engine_pid: Optional[int] = None
|
||||
# Do profile or not
|
||||
do_profile: bool = False
|
||||
# Dynamic load weight or not
|
||||
dynamic_load_weight: bool = False
|
||||
#
|
||||
pad_token_id: int = -1
|
||||
#
|
||||
eos_tokens_lens: int = 2
|
||||
# Enable chunked prefill
|
||||
enable_chunked_prefill: str = "store_true"
|
||||
"""
|
||||
- APPEND_ATTN:
|
||||
"""
|
||||
attention_backend: str = "APPEND_ATTN"
|
||||
|
||||
max_num_batched_tokens: int = 2048
|
||||
# enable prefix cache
|
||||
enable_prefix_caching = None
|
||||
@@ -354,9 +346,27 @@ class GraphOptimizationConfig:
|
||||
@dataclass
|
||||
class LoadConfig:
|
||||
"""
|
||||
Configuration for loading parameter
|
||||
Configuration for dynamic weight loading strategies
|
||||
|
||||
Attributes:
|
||||
dynamic_load_weight: Whether to enable dynamic weight loading
|
||||
load_strategy: Specifies the weight loading method when enabled:
|
||||
- 'ipc': Real-time IPC streaming with automatic resharding
|
||||
- 'ipc_no_reshard': Real-time IPC streaming without weight process
|
||||
- 'ipc_snapshot': Load from disk snapshot of IPC weights
|
||||
- 'meta': provide RL traing worker, no_weights_load
|
||||
- None: No dynamic loading
|
||||
"""
|
||||
pass
|
||||
use_fastsafetensor: bool = False
|
||||
dynamic_load_weight: bool = False
|
||||
load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.load_strategy is not None and not self.dynamic_load_weight:
|
||||
raise ValueError("Load strategy requires dynamic_load_weight=True")
|
||||
|
||||
if self.dynamic_load_weight and self.load_strategy is None:
|
||||
raise ValueError("Must specify load_strategy when dynamic_load_weight is True")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -392,7 +402,7 @@ class FDConfig:
|
||||
init=True) # type: ignore
|
||||
device_config: DeviceConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
load_config: LoadConfig = field(default=None, init=True) # type: ignore
|
||||
load_config: LoadConfig = field(default=None, init=True)
|
||||
quant_config: Optional[QuantConfigBase] = None
|
||||
graph_opt_config: Optional[GraphOptimizationConfig] = None
|
||||
moe_config: MoEConfig = field(default=None, init=True) # type: ignore
|
||||
|
@@ -16,48 +16,54 @@
|
||||
|
||||
import time
|
||||
import os
|
||||
import subprocess
|
||||
import signal
|
||||
import multiprocessing
|
||||
|
||||
from fastdeploy.entrypoints.llm import LLM
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
|
||||
|
||||
model_name_or_path = "./models/eb45t02/"
|
||||
|
||||
model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle"
|
||||
|
||||
|
||||
|
||||
prefill_cmd = (f"FD_LOG_DIR=log_prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python fastdeploy.entrypoints.openai.api_server.py"
|
||||
+ f" --model {model_name_or_path} --port 9811"
|
||||
+ f" --splitwise-role prefill --tensor-parallel-size 4"
|
||||
+ f" --engine-worker-queue-port 6676 --cache-queue-port 55663")
|
||||
def start_decode(model_name_or_path):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
os.environ["FD_LOG_DIR"] = "log_decode"
|
||||
llm_decode = LLM(
|
||||
model=model_name_or_path,
|
||||
tensor_parallel_size=1,
|
||||
splitwise_role="decode",
|
||||
engine_worker_queue_port=6678,
|
||||
innode_prefill_ports=[6676],
|
||||
cache_queue_port=55668
|
||||
)
|
||||
return llm_decode
|
||||
|
||||
prefill_instance = subprocess.Popen(
|
||||
prefill_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
shell=True,
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
def start_prefill(model_name_or_path):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
os.environ["FD_LOG_DIR"] = "log_prefill"
|
||||
llm_prefill = LLM(
|
||||
model=model_name_or_path,
|
||||
tensor_parallel_size=1,
|
||||
splitwise_role="prefill",
|
||||
engine_worker_queue_port=6677,
|
||||
cache_queue_port=55667,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
prefill = multiprocessing.Process(
|
||||
target=start_prefill,
|
||||
args=(model_name_or_path,)).start()
|
||||
time.sleep(10)
|
||||
llm_decode = start_decode(model_name_or_path)
|
||||
|
||||
output = llm_decode.generate(prompts=["who are you?", "what can you do?"], use_tqdm=True)
|
||||
print(output)
|
||||
|
||||
decode.join()
|
||||
|
||||
|
||||
# # 超参设置
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
|
||||
os.environ["FD_LOG_DIR"] = "log_decode"
|
||||
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
|
||||
llm_decode = LLM(
|
||||
model=model_name_or_path,
|
||||
tensor_parallel_size=4,
|
||||
splitwise_role="decode",
|
||||
engine_worker_queue_port=6678,
|
||||
innode_prefill_ports=[6676],
|
||||
cache_queue_port=55668
|
||||
)
|
||||
|
||||
|
||||
output = llm_decode.generate(prompts=["who are you?", "what can you do?"], use_tqdm=True)
|
||||
print(output)
|
||||
|
||||
|
||||
os.killpg(prefill_instance.pid, signal.SIGTERM)
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@@ -17,13 +17,15 @@
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
|
||||
|
||||
@paddle.jit.marker.unified
|
||||
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
if paddle.in_dynamic_mode():
|
||||
hcg = dist.fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
dist.all_reduce(input_, group=mp_group)
|
||||
else:
|
||||
dist.all_reduce(input_)
|
||||
try:
|
||||
@paddle.jit.marker.unified
|
||||
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
if paddle.in_dynamic_mode():
|
||||
hcg = dist.fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
dist.all_reduce(input_, group=mp_group)
|
||||
else:
|
||||
dist.all_reduce(input_)
|
||||
except:
|
||||
tensor_model_parallel_all_reduce=None
|
@@ -87,10 +87,14 @@ class EngineArgs:
|
||||
"""
|
||||
Configuration for speculative execution.
|
||||
"""
|
||||
dynamic_load_weight: int = 0
|
||||
dynamic_load_weight: bool = False
|
||||
"""
|
||||
dynamic load weight
|
||||
"""
|
||||
load_strategy: str = "meta"
|
||||
"""
|
||||
dynamic load weight strategy
|
||||
"""
|
||||
quantization: str = None
|
||||
guided_decoding_backend: str = "off"
|
||||
"""
|
||||
@@ -364,13 +368,16 @@ class EngineArgs:
|
||||
type=json.loads,
|
||||
default=EngineArgs.speculative_config,
|
||||
help="Configuration for speculative execution.")
|
||||
|
||||
model_group.add_argument(
|
||||
"--dynamic-load-weight",
|
||||
type=int,
|
||||
action='store_true',
|
||||
default=EngineArgs.dynamic_load_weight,
|
||||
help="Flag to indicate whether to load weight dynamically.")
|
||||
|
||||
model_group.add_argument(
|
||||
"--load-strategy",
|
||||
type=str,
|
||||
default=EngineArgs.load_strategy,
|
||||
help="Flag to dynamic load strategy.")
|
||||
model_group.add_argument("--engine-worker-queue-port",
|
||||
type=int,
|
||||
default=EngineArgs.engine_worker_queue_port,
|
||||
@@ -383,6 +390,7 @@ class EngineArgs:
|
||||
"default is None. The priority of this configuration "\
|
||||
"is lower than that of the config file. " \
|
||||
"More complex quantization methods need to be configured via the config file.")
|
||||
|
||||
model_group.add_argument(
|
||||
"--enable-static-graph-inference",
|
||||
action='store_true',
|
||||
@@ -668,8 +676,9 @@ class EngineArgs:
|
||||
"""
|
||||
return ModelConfig(model_name_or_path=self.model,
|
||||
config_json_file=self.model_config_name,
|
||||
quantization=self.quantization,
|
||||
dynamic_load_weight=self.dynamic_load_weight,
|
||||
quantization=self.quantization)
|
||||
load_strategy=self.load_strategy)
|
||||
|
||||
def create_cache_config(self, model_cfg) -> CacheConfig:
|
||||
"""
|
||||
@@ -749,6 +758,9 @@ class EngineArgs:
|
||||
|
||||
speculative_cfg = self.create_speculative_config()
|
||||
|
||||
assert not (self.use_cudagraph and self.enable_prefix_caching), \
|
||||
"Prefix caching cannot be used with CUDA graph"
|
||||
|
||||
return Config(
|
||||
model_name_or_path=self.model,
|
||||
model_config=model_cfg,
|
||||
|
@@ -17,6 +17,7 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastdeploy import envs
|
||||
@@ -41,7 +42,8 @@ class ModelConfig:
|
||||
def __init__(self,
|
||||
model_name_or_path: str,
|
||||
config_json_file: str = "config.json",
|
||||
dynamic_load_weight: int = 0,
|
||||
dynamic_load_weight: bool = False,
|
||||
load_strategy: str="meta",
|
||||
quantization: str = None,
|
||||
download_dir: Optional[str] = None):
|
||||
"""
|
||||
@@ -55,6 +57,7 @@ class ModelConfig:
|
||||
self.model_dir = model_name_or_path
|
||||
self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
|
||||
self.dynamic_load_weight = dynamic_load_weight
|
||||
self.load_strategy = load_strategy
|
||||
self.quantization = quantization
|
||||
|
||||
config_file = os.path.join(model_name_or_path, config_json_file)
|
||||
@@ -465,7 +468,63 @@ class ParallelConfig:
|
||||
llm_logger.info("Parallel Configuration Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
llm_logger.info("==================")
|
||||
llm_logger.info(
|
||||
"=============================================================")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitConfig:
|
||||
"""
|
||||
Configuration for tracking version information from version.txt
|
||||
|
||||
Attributes:
|
||||
fastdeploy_commit: Full FastDeploy git commit hash
|
||||
paddle_version: PaddlePaddle version string
|
||||
paddle_commit: PaddlePaddle git commit hash
|
||||
cuda_version: CUDA version string
|
||||
compiler_version: CXX compiler version string
|
||||
"""
|
||||
fastdeploy_commit: str = ""
|
||||
paddle_version: str = ""
|
||||
paddle_commit: str = ""
|
||||
cuda_version: str = ""
|
||||
compiler_version: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Automatically load version info when initialized"""
|
||||
self._load_from_version_file()
|
||||
|
||||
def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
|
||||
"""Internal method to load version info from file"""
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("fastdeploy GIT COMMIT ID:"):
|
||||
self.fastdeploy_commit = line.split(":")[1].strip()
|
||||
elif line.startswith("Paddle version:"):
|
||||
self.paddle_version = line.split(":")[1].strip()
|
||||
elif line.startswith("Paddle GIT COMMIT ID:"):
|
||||
self.paddle_commit = line.split(":")[1].strip()
|
||||
elif line.startswith("CUDA version:"):
|
||||
self.cuda_version = line.split(":")[1].strip()
|
||||
elif line.startswith("CXX compiler version:"):
|
||||
self.compiler_version = line.split(":")[1].strip()
|
||||
except FileNotFoundError:
|
||||
llm_logger.info(f"Warning: Version file not found at {file_path}")
|
||||
except Exception as e:
|
||||
llm_logger.info(f"Warning: Could not read version file - {str(e)}")
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
print all config
|
||||
|
||||
"""
|
||||
llm_logger.info("Fasedeploy Commit Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
llm_logger.info(
|
||||
"=============================================================")
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -500,6 +559,7 @@ class Config:
|
||||
cache_config: CacheConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
commit_config: CommitConfig = CommitConfig(),
|
||||
model_name_or_path: str = None,
|
||||
tokenizer: str = None,
|
||||
tensor_parallel_size: int = 8,
|
||||
@@ -559,6 +619,7 @@ class Config:
|
||||
self.cache_config = cache_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.parallel_config = parallel_config
|
||||
self.commit_config = commit_config
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.tokenizer = tokenizer
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
@@ -584,12 +645,10 @@ class Config:
|
||||
self.guided_decoding_backend = guided_decoding_backend
|
||||
self.disable_any_whitespace = disable_any_whitespace
|
||||
|
||||
|
||||
if self.innode_prefill_ports is not None:
|
||||
if not isinstance(self.innode_prefill_ports, list):
|
||||
ports = str(self.innode_prefill_ports).split(',')
|
||||
self.innode_prefill_ports = [int(port) for port in ports]
|
||||
|
||||
|
||||
assert self.splitwise_role in ["mixed", "prefill", "decode"]
|
||||
|
||||
@@ -728,7 +787,7 @@ class Config:
|
||||
), "XPU currently do not support guided_decoding"
|
||||
|
||||
try:
|
||||
pass
|
||||
import xgrammar # noqa
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
|
||||
@@ -749,7 +808,11 @@ class Config:
|
||||
if k == "generation_config" and v is not None:
|
||||
for gck, gcv in v.to_dict().items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
|
||||
elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
|
||||
elif (k == "cache_config" or
|
||||
k == "model_config" or
|
||||
k == "scheduler_config" or
|
||||
k == "parallel_config" or
|
||||
k == "commit_config"):
|
||||
v.print()
|
||||
else:
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
|
@@ -286,6 +286,8 @@ class LLMEngine(object):
|
||||
while self.running:
|
||||
try:
|
||||
results = self.scheduler.get_results()
|
||||
if len(results) == 0:
|
||||
time.sleep(0.001)
|
||||
for request_id, contents in results.items():
|
||||
for result in contents:
|
||||
self.zmq_server.send_multipart(request_id, result)
|
||||
@@ -444,8 +446,8 @@ class LLMEngine(object):
|
||||
enable_thinking = None
|
||||
if kwargs is not None:
|
||||
enable_thinking = kwargs.get("enable_thinking", None)
|
||||
request = self.data_processor.process_request(request,
|
||||
self.cfg.max_model_len, enable_thinking=enable_thinking)
|
||||
request = self.data_processor.process_request(
|
||||
request, self.cfg.max_model_len, enable_thinking=enable_thinking)
|
||||
request.prompt_token_ids_len = len(request.prompt_token_ids)
|
||||
input_ids_len = request.prompt_token_ids_len
|
||||
request.set(
|
||||
@@ -453,7 +455,8 @@ class LLMEngine(object):
|
||||
min(self.cfg.max_model_len - input_ids_len,
|
||||
request.get("max_tokens")))
|
||||
if request.get("reasoning_max_tokens") is None:
|
||||
default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1)
|
||||
default_reasoning_max_tokens = max(
|
||||
int(request.get("max_tokens") * 0.8), 1)
|
||||
request.set("reasoning_max_tokens", default_reasoning_max_tokens)
|
||||
min_tokens = request.get("min_tokens")
|
||||
if input_ids_len + min_tokens >= self.cfg.max_model_len:
|
||||
@@ -963,8 +966,8 @@ class LLMEngine(object):
|
||||
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
|
||||
"FLAGS_use_append_attn": 1,
|
||||
"NCCL_ALGO": "Ring",
|
||||
"FLAGS_hardamard_moe_block_size": 128,
|
||||
"FLAGS_max_partition_size": 32768,
|
||||
"FLAGS_hardamard_moe_block_size": 128,
|
||||
}
|
||||
# environment variables needed by Dy2St
|
||||
variables.update({
|
||||
@@ -1017,6 +1020,12 @@ class LLMEngine(object):
|
||||
worker_path = "../worker/vl_worker_process.py"
|
||||
py_script = os.path.join(current_dir_path, worker_path)
|
||||
|
||||
ori_vocab_size = (
|
||||
len(self.data_processor.tokenizer.sp_model)
|
||||
if hasattr(self.data_processor.tokenizer, 'sp_model')
|
||||
else len(self.data_processor.tokenizer.vocab)
|
||||
)
|
||||
|
||||
arguments = (
|
||||
f" --nnodes {str(self.cfg.nnode)}"
|
||||
f" --devices {self.cfg.device_ids} {py_script}"
|
||||
@@ -1037,13 +1046,14 @@ class LLMEngine(object):
|
||||
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
|
||||
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
|
||||
f" --quantization {self.cfg.model_config.quantization}"
|
||||
f" --ori_vocab_size {len(self.data_processor.tokenizer)}"
|
||||
f" --ori_vocab_size {ori_vocab_size}"
|
||||
f" --speculative_method {self.cfg.speculative_config.method}"
|
||||
f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}"
|
||||
f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
|
||||
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
|
||||
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
|
||||
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}")
|
||||
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
|
||||
f" --load_strategy {self.cfg.model_config.load_strategy}")
|
||||
|
||||
worker_append_flag = {
|
||||
"enable_expert_parallel":
|
||||
@@ -1188,8 +1198,9 @@ class LLMEngine(object):
|
||||
line = line.decode('utf-8', errors='ignore')
|
||||
if self.worker_init_status.get("finished", False):
|
||||
break
|
||||
if match := re.search(r'Loading checkpoint shards:\s*(\d+)',
|
||||
line):
|
||||
if match := re.search(
|
||||
r'Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)',
|
||||
line):
|
||||
self.worker_init_status["weight_loadding"] = eval(
|
||||
match.group(1)) * 1.0 / 100
|
||||
elif (match := re.search(r'Start load layer (\d+)',
|
||||
|
@@ -122,7 +122,7 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
"""
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
|
@@ -220,6 +220,9 @@ class OpenAIServingChat:
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
||||
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
|
||||
if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
|
||||
choice.delta.token_ids = output["token_ids"]
|
||||
@@ -335,6 +338,9 @@ class OpenAIServingChat:
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
||||
if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
choices.append(choice)
|
||||
|
||||
num_prompt_tokens = len(prompt_token_ids)
|
||||
|
@@ -82,13 +82,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_MOE_BACKEND":
|
||||
lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
|
||||
|
||||
# Set whether to disable recompute the request when the KV cache is full.
|
||||
"FD_DISABLED_RECOVER":
|
||||
lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
|
||||
|
||||
# Set triton kernel JIT compilation directory.
|
||||
"FD_TRITON_KERNEL_CACHE_DIR":
|
||||
lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None),
|
||||
|
||||
# Whether transition from standalone PD decoupling to centralized inference
|
||||
"FD_PD_CHANGEABLE":
|
||||
lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
|
||||
lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
|
||||
|
||||
# Whether to use fastsafetensor load weight (0 or 1)
|
||||
"FD_USE_FASTSAFETENSOR":
|
||||
lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
|
||||
}
|
||||
|
||||
|
||||
|
@@ -27,6 +27,7 @@ from fastdeploy.input.text_processor import BaseDataProcessor
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class ErnieProcessor(BaseDataProcessor):
|
||||
"""
|
||||
初始化模型实例。
|
||||
@@ -160,6 +161,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
if request.get('prompt'):
|
||||
prompt = request.get('prompt')
|
||||
prompt = prompt[0] if isinstance(prompt, list) else prompt
|
||||
|
||||
tokens = self.tokenizer.tokenize(prompt)
|
||||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
request['prompt_token_ids'] = token_ids
|
||||
|
@@ -82,6 +82,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
self.vocab_file = vocab_file
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(vocab_file)
|
||||
# pre-process map-type all spec token for decode accelerate.
|
||||
|
||||
@property
|
||||
def space_token(self):
|
||||
@@ -136,14 +137,19 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
"""doc"""
|
||||
return self.sp_model.id_to_piece(id)
|
||||
|
||||
def spec_init(self):
|
||||
if not hasattr(self, "all_spec_tok"):
|
||||
self.all_spec_tok = set(self.all_special_tokens)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
self.spec_init()
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
# prev_is_special = False
|
||||
for token in tokens:
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
if token in self.all_spec_tok:
|
||||
# if not prev_is_special:
|
||||
# out_string += " "
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
@@ -210,13 +216,14 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
# if isinstance(t, AddedToken)
|
||||
# )
|
||||
|
||||
self.spec_init()
|
||||
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
||||
|
||||
# TODO: should this be in the base class?
|
||||
if hasattr(self, "do_lower_case") and self.do_lower_case:
|
||||
# convert non-special tokens to lowercase
|
||||
escaped_special_toks = [
|
||||
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
|
||||
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok)
|
||||
]
|
||||
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
||||
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
||||
|
@@ -25,6 +25,7 @@ from fastdeploy.utils import data_processor_logger
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class BaseDataProcessor(ABC):
|
||||
"""base class for data processor"""
|
||||
|
||||
|
@@ -12,14 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .attention import Attention
|
||||
from .append_attn_backend import AppendAttentionBackend
|
||||
from .attention_selecter import get_attention_backend
|
||||
from .base_attention_backend import AttentionBackend
|
||||
from .flash_attn_backend import FlashAttentionBackend
|
||||
from .mla_attention_backend import MLAAttentionBackend
|
||||
from .native_paddle_backend import PaddleNativeAttnBackend
|
||||
from .xpu_attn_backend import XPUAttentionBackend
|
||||
|
||||
__all__ = [
|
||||
"Attention", "AttentionBackend", "PaddleNativeAttnBackend",
|
||||
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend"
|
||||
"AttentionBackend", "PaddleNativeAttnBackend",
|
||||
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
|
||||
"MLAAttentionBackend", "FlashAttentionBackend"
|
||||
]
|
||||
|
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
@@ -187,6 +187,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
|
@@ -111,6 +111,8 @@ class Attention(nn.Layer):
|
||||
k: paddle.Tensor = None,
|
||||
v: paddle.Tensor = None,
|
||||
qkv: paddle.Tensor = None,
|
||||
compressed_kv: paddle.Tensor = None,
|
||||
k_pe: paddle.Tensor = None,
|
||||
forward_meta: ForwardMeta = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
@@ -120,12 +122,16 @@ class Attention(nn.Layer):
|
||||
k: the key tensor
|
||||
v: the value tensor
|
||||
forward_meta: the forward meta data
|
||||
compressed_kv: optional compressed key-value cache (for MLA)
|
||||
k_pe: optional key positional encoding (for MLA)
|
||||
"""
|
||||
return forward_meta.attn_backend.forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
self,
|
||||
forward_meta,
|
||||
)
|
||||
|
@@ -16,6 +16,7 @@
|
||||
|
||||
from functools import cache
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.platforms import _Backend, current_platform
|
||||
from fastdeploy.utils import resolve_obj_from_strname
|
||||
|
||||
@@ -40,6 +41,7 @@ def _get_attn_backend(selected_backend: str) -> object:
|
||||
return resolve_obj_from_strname(attention_cls)
|
||||
|
||||
|
||||
def get_attention_backend(selected_backend):
|
||||
"""Selects which attention backend ."""
|
||||
return _get_attn_backend(selected_backend)
|
||||
def get_attention_backend() -> object:
|
||||
"""Selects which attention backend."""
|
||||
attention_backend = envs.FD_ATTENTION_BACKEND
|
||||
return _get_attn_backend(attention_backend)
|
||||
|
@@ -46,6 +46,8 @@ class AttentionBackend(ABC):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
@@ -56,6 +58,8 @@ class AttentionBackend(ABC):
|
||||
k: The key tensor.
|
||||
v: The value tensor.
|
||||
layer: The layer that will be used for the forward.
|
||||
compressed_kv: optional compressed key-value cache (for MLA)
|
||||
k_pe: optional key positional encoding (for MLA)
|
||||
forward_meta: The forward metadata.
|
||||
"""
|
||||
if forward_meta.forward_mode.is_mixed():
|
||||
@@ -64,6 +68,8 @@ class AttentionBackend(ABC):
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
layer,
|
||||
forward_meta,
|
||||
)
|
||||
@@ -73,6 +79,8 @@ class AttentionBackend(ABC):
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
layer,
|
||||
forward_meta,
|
||||
)
|
||||
@@ -82,6 +90,8 @@ class AttentionBackend(ABC):
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
layer,
|
||||
forward_meta,
|
||||
)
|
||||
@@ -92,6 +102,8 @@ class AttentionBackend(ABC):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
@@ -104,6 +116,8 @@ class AttentionBackend(ABC):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
@@ -116,6 +130,8 @@ class AttentionBackend(ABC):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
|
247
fastdeploy/model_executor/layers/attention/flash_attn_backend.py
Normal file
247
fastdeploy/model_executor/layers/attention/flash_attn_backend.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
try:
|
||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||
except:
|
||||
flash_attention_v3_varlen = None
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
FlashAttentionMetadata
|
||||
"""
|
||||
max_len_kv: paddle.Tensor = None
|
||||
set_max_lengths: int = -1
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
decoder_batch_ids: paddle.Tensor = None
|
||||
decoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
decoder_num_blocks: paddle.Tensor = None
|
||||
|
||||
encoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
decoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
|
||||
cu_seqlens_q: paddle.Tensor = None
|
||||
cu_seqlens_k: paddle.Tensor = None
|
||||
max_seqlen_q: int = 0
|
||||
max_seqlen_k: int = 0
|
||||
|
||||
pre_cache_batch_ids = None
|
||||
pre_cache_tile_ids_per_batch = None
|
||||
pre_cache_num_blocks_cpu = None
|
||||
kv_token_num_cpu = None
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
FlashAttentionBackend backend implementation
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||
head_dim: int):
|
||||
"""
|
||||
FlashAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: FlashAttentionMetadata = None
|
||||
self.max_seq_len = fd_config.parallel_config.max_model_len
|
||||
self.causal = getattr(fd_config.model_config, "causal", True)
|
||||
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.block_size = fd_config.parallel_config.block_size
|
||||
self.num_layers: int = fd_config.model_config.num_layers
|
||||
|
||||
self.speculative_method = fd_config.speculative_config.method
|
||||
self.use_speculate = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
fd_config.parallel_config.expert_parallel_rank = 0
|
||||
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
|
||||
fd_config.parallel_config.expert_parallel_rank
|
||||
if self.device_id is None:
|
||||
self.device_id = device_id
|
||||
else:
|
||||
self.device_id = self.device_id.split(",")[device_id]
|
||||
|
||||
def get_attntion_meta(self):
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
def get_kv_cache_shape(
|
||||
self,
|
||||
max_num_blocks: int,
|
||||
):
|
||||
"""
|
||||
Caculate kv cache shape
|
||||
"""
|
||||
return (max_num_blocks, self.kv_num_heads, self.block_size,
|
||||
self.head_dim)
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
metadata = FlashAttentionMetadata()
|
||||
metadata.encoder_block_shape_q = 64
|
||||
metadata.decoder_block_shape_q = 16
|
||||
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
metadata.set_max_lengths,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.encoder_block_shape_q,
|
||||
metadata.decoder_block_shape_q,
|
||||
self.num_heads // self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
)
|
||||
|
||||
(
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
metadata.kv_token_num_cpu,
|
||||
) = pre_cache_len_concat(
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
metadata.set_max_lengths[2],
|
||||
self.block_size,
|
||||
)
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
self.attention_metadata = metadata
|
||||
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
||||
forward_meta.decoder_tile_ids_per_batch.copy_(
|
||||
metadata.decoder_tile_ids_per_batch, False)
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
q, k, v, _ = gqa_rope_write_cache(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.rotary_embs,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
metadata.kv_token_num_cpu[0],
|
||||
self.max_seq_len,
|
||||
getattr(layer, "cache_quant_type_str", "none"),
|
||||
)
|
||||
res = flash_attention_v3_varlen(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
max_seqlen_q=metadata.set_max_lengths[0],
|
||||
max_seqlen_k=metadata.set_max_lengths[3],
|
||||
causal=self.causal,
|
||||
)[0].reshape([-1, self.hidden_size])
|
||||
return res
|
@@ -0,0 +1,490 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import paddle
|
||||
from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, init_signal_layerwise,
|
||||
open_shm_and_get_meta_signal)
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import (decode_mla_write_cache,
|
||||
multi_head_latent_attention,
|
||||
prefill_mla_write_cache)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
"""
|
||||
"""
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLAAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
MLAAttentionMetadata for Multi-Layer Attention
|
||||
"""
|
||||
max_len_kv: paddle.Tensor = None
|
||||
set_max_lengths: int = -1
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
decoder_batch_ids: paddle.Tensor = None
|
||||
decoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
decoder_num_blocks: paddle.Tensor = None
|
||||
|
||||
_dtype: _DTypeLiteral = paddle.bfloat16
|
||||
encoder_max_partition_size: int = 32768
|
||||
max_partition_size: int = 32768
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
attn_mask: Optional[paddle.Tensor] = None
|
||||
encoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
decoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
_fuse_kernel_compute_dtype: str = "bf16"
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
|
||||
|
||||
|
||||
class MLAAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
MLA Attention Backend implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||
head_dim: int) -> None:
|
||||
"""
|
||||
MLAAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: MLAAttentionMetadata = None
|
||||
|
||||
# 基础配置
|
||||
self.block_size: int = fd_config.parallel_config.block_size
|
||||
self.max_seq_len: int = fd_config.parallel_config.max_model_len
|
||||
self.rope_theta: float = (10000.0
|
||||
if fd_config.model_config.rope_theta is None
|
||||
else fd_config.model_config.rope_theta)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.speculative_method: str = fd_config.speculative_config.method
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
self.head_dim: int = fd_config.model_config.head_dim
|
||||
self.num_layers: int = fd_config.model_config.num_layers
|
||||
|
||||
# For Multi Head Latent Attention
|
||||
self.kv_lora_rank: int = fd_config.model_config.deepseekv3.kv_lora_rank
|
||||
self.qk_rope_head_dim: int = fd_config.model_config.deepseekv3.qk_rope_head_dim
|
||||
self.qk_head_dim: int = fd_config.model_config.deepseekv3.qk_nope_head_dim \
|
||||
+ fd_config.model_config.deepseekv3.qk_rope_head_dim
|
||||
self.attn_softmax_scale: float = self.qk_head_dim**-0.5
|
||||
if fd_config.model_config.deepseekv3.rope_scaling:
|
||||
mscale_all_dim = fd_config.model_config.deepseekv3.rope_scaling.get(
|
||||
"mscale_all_dim", False) # 1.0
|
||||
scaling_factor = fd_config.model_config.deepseekv3.rope_scaling[
|
||||
"factor"] # 40
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
if self.device_id is None:
|
||||
self.device_id = self.rank
|
||||
else:
|
||||
self.device_id = self.device_id.split(",")[self.rank]
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
|
||||
metadata = MLAAttentionMetadata()
|
||||
metadata.encoder_block_shape_q = 64
|
||||
metadata.decoder_block_shape_q = 16
|
||||
metadata.max_partition_size = 32768
|
||||
metadata.encoder_max_partition_size = self.max_seq_len
|
||||
metadata._dtype = paddle.get_default_dtype()
|
||||
if metadata._dtype == "bfloat16":
|
||||
metadata._fuse_kernel_compute_dtype = "bf16"
|
||||
elif metadata._dtype == "float16":
|
||||
metadata._fuse_kernel_compute_dtype = "fp16"
|
||||
elif metadata._dtype == "float32":
|
||||
metadata._fuse_kernel_compute_dtype = "fp32"
|
||||
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.attn_mask = forward_meta.attn_mask
|
||||
metadata.pre_caches_length = forward_meta.pre_caches_length
|
||||
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
metadata.set_max_lengths,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.encoder_block_shape_q,
|
||||
metadata.decoder_block_shape_q,
|
||||
self.num_heads // self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
)
|
||||
|
||||
# MLA
|
||||
metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
|
||||
metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
|
||||
self.attention_metadata: AttentionMetadata = metadata
|
||||
|
||||
def get_attntion_meta(self) -> AttentionMetadata:
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
def get_kv_cache_shape(self,
|
||||
max_num_blocks: int) -> Tuple[int, int, int, int]:
|
||||
"""
|
||||
Calculate kv cache shape for MLA
|
||||
"""
|
||||
return (max_num_blocks, 1, self.block_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Prefill阶段的前向传播
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
|
||||
forward_meta, 'caches') else None
|
||||
|
||||
# 写入缓存
|
||||
prefill_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
"none",
|
||||
getattr(forward_meta, 'max_input_length', -1),
|
||||
)
|
||||
|
||||
# Flash注意力计算
|
||||
fmha_out = flash_attn_unpadded(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.cu_seqlens_k,
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_enc_len_this_time,
|
||||
self.attn_softmax_scale,
|
||||
causal=True,
|
||||
training=False,
|
||||
)[0]
|
||||
|
||||
return fmha_out
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Decode阶段的前向传播
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
|
||||
forward_meta, 'caches') else None
|
||||
|
||||
# 获取推测解码参数
|
||||
speculate_decoder = self.speculative_method is not None
|
||||
speculate_max_tokens = self.speculate_max_draft_token_num
|
||||
|
||||
# 写入缓存
|
||||
decode_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
"none",
|
||||
self.max_seq_len,
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
# 多头潜在注意力计算
|
||||
fmha_out = multi_head_latent_attention(
|
||||
q,
|
||||
latent_cache,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.
|
||||
decoder_num_blocks, # PaddleNLP 传入的是 decoder_num_blocks_cpu
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_dec_len_this_time,
|
||||
metadata.max_len_kv,
|
||||
None, # attn_mask
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
None, # cache_k_quant_scales
|
||||
None, # cache_v_quant_scales
|
||||
None, # cache_k_dequant_scales
|
||||
None, # cache_v_dequant_scales
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None, # out_shifts
|
||||
None, # out_smooths
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
"none", # cache_quant_type
|
||||
self.kv_lora_rank,
|
||||
self.max_seq_len,
|
||||
self.attn_softmax_scale,
|
||||
0.0, # quant_max_bound
|
||||
0.0, # quant_min_bound
|
||||
0.0, # out_linear_in_scale
|
||||
speculate_max_tokens,
|
||||
True, # causal
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
return fmha_out
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Mixed模式的前向传播
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
speculate_decoder = self.speculative_method is not None
|
||||
speculate_max_tokens = self.speculate_max_draft_token_num
|
||||
|
||||
decode_stage = forward_meta.is_decode_batch
|
||||
prefill_stage = not (forward_meta.is_decode_batch)
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
|
||||
forward_meta, 'caches') else None
|
||||
|
||||
if prefill_stage:
|
||||
# 写入缓存
|
||||
prefill_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
"none",
|
||||
self.max_seq_len,
|
||||
)
|
||||
|
||||
# FA
|
||||
fmha_out = flash_attn_unpadded(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.cu_seqlens_k,
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_enc_len_this_time,
|
||||
self.attn_softmax_scale,
|
||||
causal=True,
|
||||
training=False,
|
||||
)[0]
|
||||
|
||||
return fmha_out
|
||||
|
||||
# Decode
|
||||
if decode_stage:
|
||||
# mla写入缓存
|
||||
decode_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
"none",
|
||||
self.max_seq_len,
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
# 多头潜在注意力计算
|
||||
fmha_out = multi_head_latent_attention(
|
||||
q,
|
||||
latent_cache,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.
|
||||
decoder_num_blocks, # PaddleNLP 传入的是 decoder_num_blocks_cpu
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_dec_len_this_time,
|
||||
metadata.max_len_kv,
|
||||
None, # attn_mask
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
None, # cache_k_quant_scales
|
||||
None, # cache_v_quant_scales
|
||||
None, # cache_k_dequant_scales
|
||||
None, # cache_v_dequant_scales
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None, # out_shifts
|
||||
None, # out_smooths
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
"none", # cache_quant_type
|
||||
self.kv_lora_rank,
|
||||
self.max_seq_len,
|
||||
self.attn_softmax_scale,
|
||||
0.0, # quant_max_bound
|
||||
0.0, # quant_min_bound
|
||||
0.0, # out_linear_in_scale
|
||||
speculate_max_tokens,
|
||||
True, # causal
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
return fmha_out
|
@@ -17,10 +17,16 @@
|
||||
from .append_attention import append_attention
|
||||
from .get_block_shape_and_split_kv_block import \
|
||||
get_block_shape_and_split_kv_block
|
||||
from .gqa_rope_write_cache import gqa_rope_write_cache
|
||||
from .init_signal_layerwise import init_signal_layerwise
|
||||
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
|
||||
from .pre_cache_len_concat import pre_cache_len_concat
|
||||
|
||||
__all__ = [
|
||||
"get_block_shape_and_split_kv_block", "append_attention",
|
||||
"open_shm_and_get_meta_signal", "init_signal_layerwise"
|
||||
"get_block_shape_and_split_kv_block",
|
||||
"append_attention",
|
||||
"open_shm_and_get_meta_signal",
|
||||
"init_signal_layerwise",
|
||||
"gqa_rope_write_cache",
|
||||
"pre_cache_len_concat",
|
||||
]
|
||||
|
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def gqa_rope_write_cache(
|
||||
qkv: paddle.Tensor,
|
||||
key_cache: paddle.Tensor,
|
||||
value_cache: paddle.Tensor,
|
||||
cu_seqlens_q: paddle.Tensor,
|
||||
cu_seqlens_k: paddle.Tensor,
|
||||
rotary_embs: paddle.Tensor,
|
||||
seq_lens_this_time: paddle.Tensor,
|
||||
seq_lens_encoder: paddle.Tensor,
|
||||
seq_lens_decoder: paddle.Tensor,
|
||||
padding_offsets: paddle.Tensor,
|
||||
cum_offsets: paddle.Tensor,
|
||||
block_tables: paddle.Tensor,
|
||||
kv_batch_ids: paddle.Tensor,
|
||||
kv_tile_ids_per_batch: paddle.Tensor,
|
||||
kv_num_blocks: paddle.Tensor,
|
||||
cache_batch_ids: paddle.Tensor,
|
||||
cache_tile_ids_per_batch: paddle.Tensor,
|
||||
cache_num_blocks: paddle.Tensor,
|
||||
cache_k_quant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_v_quant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_k_dequant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_v_dequant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_k_zp: Optional[paddle.Tensor] = None,
|
||||
cache_v_zp: Optional[paddle.Tensor] = None,
|
||||
kv_signal_data: Optional[paddle.Tensor] = None,
|
||||
kv_token_num: int = 1,
|
||||
max_seq_len: int = 0,
|
||||
cache_quant_type: str = "none"):
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
|
||||
q, k, v, qkv_ = gqa_rope_write_cache(
|
||||
qkv, key_cache, value_cache, cu_seqlens_q, cu_seqlens_k,
|
||||
rotary_embs, seq_lens_this_time, seq_lens_encoder,
|
||||
seq_lens_decoder, padding_offsets, cum_offsets, block_tables,
|
||||
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
|
||||
cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks,
|
||||
cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales,
|
||||
cache_v_dequant_scales, cache_k_zp, cache_v_zp, kv_signal_data,
|
||||
kv_token_num, max_seq_len, cache_quant_type)
|
||||
return q, k, v, qkv_
|
||||
else:
|
||||
raise NotImplementedError()
|
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License,
|
||||
Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||
either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def pre_cache_len_concat(seq_lens_decoder: paddle.Tensor,
|
||||
seq_lens_this_time: paddle.Tensor,
|
||||
max_dec_len: int = 0,
|
||||
block_size: int = 64):
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
|
||||
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time,
|
||||
max_dec_len, block_size)
|
||||
return out
|
||||
else:
|
||||
raise NotImplementedError()
|
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
@@ -149,6 +149,8 @@ class XPUAttentionBackend(AttentionBackend):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
|
@@ -41,16 +41,12 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
"""
|
||||
Create weights for linear layer on XPU
|
||||
"""
|
||||
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
|
||||
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
|
||||
layer.linear_weight_shape.reverse()
|
||||
if self.quant_config.name() == "weight_only_int4":
|
||||
layer.linear_weight_shape[0] //= 2
|
||||
layer.weight_dtype = "int8"
|
||||
linear_weight_scale_shape = [layer.embed_dim]
|
||||
if hasattr(layer, "linear_weight_shape"):
|
||||
if isinstance(layer.linear_weight_shape, list):
|
||||
layer_weight_shape = layer.linear_weight_shape
|
||||
linear_weight_scale_shape = layer_weight_shape[:1]
|
||||
|
||||
layer.linear_weight_scale = layer.create_parameter(
|
||||
shape=linear_weight_scale_shape,
|
||||
dtype="float32",
|
||||
|
@@ -14,10 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
|
||||
@@ -28,12 +33,12 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
num_embeddings,
|
||||
embedding_dim=768,
|
||||
params_dtype="bfloat16",
|
||||
fd_config: FDConfig,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int = 768,
|
||||
params_dtype: str = "bfloat16",
|
||||
prefix="",
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the VocabParallelEmbedding layer for the model.
|
||||
|
||||
@@ -41,28 +46,28 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
num_embeddings : vocabulary size.
|
||||
embedding_dim : size of hidden state.
|
||||
params_dtype : data type of parameters.
|
||||
prefix (str): Unique name of the layer, used for naming internal attributes,
|
||||
you can give it any name you like.
|
||||
num_embeddings (int) : vocabulary size.
|
||||
embedding_dim (int) : size of hidden state.
|
||||
params_dtype (str) : data type of parameters.
|
||||
prefix (str): The name of current layer. Defaults to "".
|
||||
"""
|
||||
super().__init__()
|
||||
self.fd_config = fd_config
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
self.mp_rank = hcg.get_model_parallel_rank()
|
||||
self.column_cut = fd_config.parallel_config.column_cut
|
||||
self.world_size = hcg.get_model_parallel_world_size()
|
||||
self.ring_id = hcg.get_model_parallel_group().id
|
||||
self.use_rope = fd_config.model_config.use_rope
|
||||
self.rope_head_dim = fd_config.model_config.rope_head_dim
|
||||
self.use_ep = fd_config.parallel_config.use_ep
|
||||
self.hidden_dropout_prob = fd_config.model_config.hidden_dropout_prob
|
||||
self.initializer_range = fd_config.model_config.initializer_range
|
||||
self.sequence_parallel = fd_config.parallel_config.sequence_parallel
|
||||
self.max_position_embeddings = fd_config.model_config.max_position_embeddings
|
||||
self.freeze_embedding = fd_config.model_config.freeze_embedding
|
||||
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
|
||||
self.mp_rank: int = hcg.get_model_parallel_rank()
|
||||
self.column_cut = False
|
||||
self.world_size: int = hcg.get_model_parallel_world_size()
|
||||
self.ring_id: int = hcg.get_model_parallel_group().id
|
||||
self.use_rope: bool = fd_config.model_config.use_rope
|
||||
self.rope_head_dim: int = fd_config.model_config.rope_head_dim
|
||||
self.use_ep: bool = fd_config.parallel_config.use_ep
|
||||
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
|
||||
self.initializer_range: float = fd_config.model_config.initializer_range
|
||||
self.sequence_parallel: bool = fd_config.parallel_config.sequence_parallel
|
||||
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
|
||||
self.freeze_embedding: bool = fd_config.model_config.freeze_embedding
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
self.params_dtype: str = params_dtype
|
||||
|
||||
if self.use_ep:
|
||||
self.word_embeddings = nn.Embedding(
|
||||
@@ -109,7 +114,8 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
self.rope_head_dim_shape_tensor = paddle.ones((self.rope_head_dim),
|
||||
dtype="int8")
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
@@ -125,7 +131,7 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
|
||||
paddle.get_default_dtype()))
|
||||
|
||||
def forward(self, ids_remove_padding=None):
|
||||
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
|
@@ -216,6 +216,14 @@ class ReplicatedLinear(LinearBase):
|
||||
with_bias=with_bias,
|
||||
add_bias=add_bias,
|
||||
skip_quant=skip_quant)
|
||||
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.linear_weight_shape = [
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
]
|
||||
if fd_config.quant_config:
|
||||
self.quant_method.create_weights(self)
|
||||
self.init_weight()
|
||||
|
||||
|
||||
@@ -259,7 +267,10 @@ class ColumnParallelLinear(LinearBase):
|
||||
skip_quant=skip_quant)
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.input_size = input_size
|
||||
self.output_size = divide(output_size, self.nranks)
|
||||
self.output_size = divide(
|
||||
output_size,
|
||||
self.nranks) # Split the output_size using TP inference.
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.linear_weight_shape = [
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
@@ -282,7 +293,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
)
|
||||
if self.nranks > 0:
|
||||
# col parallel
|
||||
_set_var_distributed(self.linear_weight, split_axis=-1)
|
||||
_set_var_distributed(self.linear_weight, split_axis=1)
|
||||
|
||||
self.linear_bias = None
|
||||
if self.with_bias:
|
||||
@@ -293,7 +304,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
)
|
||||
if self.nranks > 0:
|
||||
# col parallel
|
||||
_set_var_distributed(self.linear_bias, split_axis=-1)
|
||||
_set_var_distributed(self.linear_bias, split_axis=1)
|
||||
|
||||
# smooth quant
|
||||
self.linear_shift = None
|
||||
@@ -318,7 +329,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
with_bias: bool = False,
|
||||
add_bias: bool = False,
|
||||
activation: str = "gelu",
|
||||
use_fast_ffn: bool = False,
|
||||
skip_quant: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -333,13 +343,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
with_bias (bool): Whether to include bias or not. Defaults to False.
|
||||
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
|
||||
activation (str): Activation function to use. Defaults to "gelu".
|
||||
use_fast_ffn (bool): Whether to use a faster FFN implementation.
|
||||
Defaults to False.
|
||||
skip_quant (bool): Whether to skip quantization. Defaults to False.
|
||||
"""
|
||||
self.use_fast_ffn = use_fast_ffn
|
||||
self.activation = activation
|
||||
self.embed_dim = fd_config.model_config.hidden_size
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
|
||||
super().__init__(fd_config=fd_config,
|
||||
@@ -374,23 +381,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"gate_proj")
|
||||
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
|
||||
paddle.get_default_dtype())
|
||||
converted_bias_tensor = paddle.zeros(shape=list(
|
||||
bias_tensor.shape),
|
||||
dtype=bias_tensor.dtype)
|
||||
if not self.use_fast_ffn:
|
||||
converted_bias_tensor = paddle.concat(
|
||||
[bias_tensor[::2], bias_tensor[1::2]], axis=0)
|
||||
else:
|
||||
converted_bias_tensor = bias_tensor
|
||||
state_dict[self.bias_key] = converted_bias_tensor
|
||||
|
||||
if not self.use_fast_ffn:
|
||||
converted_weight_tensor = paddle.concat(
|
||||
[weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1)
|
||||
else:
|
||||
converted_weight_tensor = weight_tensor
|
||||
state_dict[self.bias_key] = bias_tensor
|
||||
|
||||
state_dict[self.weight_key] = converted_weight_tensor
|
||||
state_dict[self.weight_key] = weight_tensor
|
||||
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
@@ -413,12 +407,12 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""
|
||||
self.num_heads = fd_config.model_config.num_attention_heads
|
||||
self.kv_num_heads = fd_config.model_config.num_key_value_heads
|
||||
self.embed_dim = fd_config.model_config.hidden_size
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
|
||||
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
|
||||
input_size = self.embed_dim
|
||||
input_size = self.hidden_size
|
||||
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
|
||||
super().__init__(fd_config=fd_config,
|
||||
prefix=prefix,
|
||||
@@ -448,7 +442,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
weight_tensor = weight_tensor.reshape([
|
||||
(self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) *
|
||||
(self.head_dim),
|
||||
self.embed_dim,
|
||||
self.hidden_size,
|
||||
])
|
||||
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
|
||||
|
||||
@@ -513,6 +507,7 @@ class RowParallelLinear(LinearBase):
|
||||
output_size: int = None,
|
||||
with_bias: bool = False,
|
||||
add_bias: bool = False,
|
||||
reduce_results: bool = True,
|
||||
skip_quant: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -538,10 +533,14 @@ class RowParallelLinear(LinearBase):
|
||||
self.fd_config = fd_config
|
||||
self.skip_quant = False
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.embed_dim = fd_config.model_config.hidden_size
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
|
||||
|
||||
# Split input_size when using TP inference.
|
||||
self.input_size = divide(input_size, self.nranks)
|
||||
self.output_size = output_size
|
||||
|
||||
self.linear_weight_shape = [
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
@@ -551,6 +550,8 @@ class RowParallelLinear(LinearBase):
|
||||
if fd_config.quant_config:
|
||||
self.quant_method = fd_config.quant_config.get_quant_method(self)
|
||||
self.quant_method.create_weights(self)
|
||||
|
||||
self.reduce_results = reduce_results
|
||||
self.init_weight()
|
||||
|
||||
def init_weight(self):
|
||||
@@ -570,7 +571,7 @@ class RowParallelLinear(LinearBase):
|
||||
self.linear_bias = None
|
||||
if self.with_bias:
|
||||
self.linear_bias = self.create_parameter(
|
||||
shape=[self.embed_dim],
|
||||
shape=[self.hidden_size],
|
||||
dtype=self._dtype,
|
||||
is_bias=True,
|
||||
)
|
||||
@@ -589,7 +590,7 @@ class RowParallelLinear(LinearBase):
|
||||
else:
|
||||
out = paddle.matmul(x, self.linear_weight)
|
||||
|
||||
if self.nranks > 1:
|
||||
if self.reduce_results and self.nranks > 1:
|
||||
tensor_model_parallel_all_reduce(out)
|
||||
|
||||
return out
|
||||
|
@@ -14,10 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
|
||||
@@ -28,12 +33,12 @@ class ParallelLMHead(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
prefix="",
|
||||
with_bias=False,
|
||||
):
|
||||
fd_config: FDConfig,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
prefix: str = "",
|
||||
with_bias: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Parallelized LMhead.
|
||||
|
||||
@@ -43,21 +48,22 @@ class ParallelLMHead(nn.Layer):
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
num_embeddings (int): vocabulary size.
|
||||
embedding_dim (int): size of hidden state.
|
||||
prefix (str): full name of the layer in the state dict
|
||||
prefix (str): The name of current layer. Defaults to "".
|
||||
with_bias (bool): whether to have bias. Default: False.
|
||||
"""
|
||||
super(ParallelLMHead, self).__init__()
|
||||
self.linear_weight_key = prefix + ".weight"
|
||||
self.linear_weight_key: str = prefix + ".weight"
|
||||
if with_bias:
|
||||
self.linear_bias_key = prefix + ".bias"
|
||||
self.linear_bias_key: Optional[str] = prefix + ".bias"
|
||||
else:
|
||||
self.linear_bias_key = None
|
||||
self.use_ep = fd_config.parallel_config.use_ep
|
||||
self.linear_bias_key: Optional[str] = None
|
||||
self.use_ep: bool = fd_config.parallel_config.use_ep
|
||||
self.column_cut = True
|
||||
|
||||
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
|
||||
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
|
||||
|
||||
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
|
||||
if self.use_ep:
|
||||
self.weight = self.create_parameter(
|
||||
@@ -92,7 +98,8 @@ class ParallelLMHead(nn.Layer):
|
||||
fuse_matmul_bias=False, # False diff更小
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
@@ -122,7 +129,7 @@ class ParallelLMHead(nn.Layer):
|
||||
paddle.get_default_dtype())
|
||||
self.out_linear.bias.set_value(bias)
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
|
@@ -22,14 +22,35 @@ from paddleformers.utils.log import logger
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from ..utils import get_tensor, create_and_set_parameter
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from ..utils import create_and_set_parameter, get_tensor
|
||||
from .fused_moe_backend_base import MoEMethodBase
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
|
||||
moe_expert_reduce, noaux_tc)
|
||||
|
||||
|
||||
# used for deepseek_v3
|
||||
def get_moe_scores(gating_output: paddle.Tensor, n_group, topk_group, top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||
scores = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores
|
||||
|
||||
|
||||
class CutlassMoEMethod(MoEMethodBase):
|
||||
"""
|
||||
@@ -199,23 +220,47 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
topk_only_mode=False,
|
||||
)
|
||||
if layer.topk_method == "noaux_tc":
|
||||
gate_out = get_moe_scores(gate_out, layer.n_group,
|
||||
layer.topk_group, layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias)
|
||||
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
None, # Use layer.gate_correction_bias in get_moe_scores.
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
topk_only_mode=True,
|
||||
)
|
||||
else:
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
topk_only_mode=False,
|
||||
)
|
||||
|
||||
if self.moe_quant_type != "w4a8":
|
||||
# only w4a8 need expert_idx_per_token
|
||||
@@ -234,11 +279,11 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
permute_indices_per_token,
|
||||
topk_idx,
|
||||
None,
|
||||
norm_topk_prob=True,
|
||||
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
|
||||
if layer.tp_size > 1:
|
||||
if layer.reduce_results and layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(fused_moe_out)
|
||||
|
||||
return fused_moe_out
|
||||
|
@@ -195,8 +195,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
hidden_size = layer.hidden_size
|
||||
num_experts = layer.num_experts
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
|
@@ -17,6 +17,7 @@
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
|
||||
@@ -25,17 +26,24 @@ from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Use Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_method=None):
|
||||
def __init__(self, quant_config=None):
|
||||
"""
|
||||
Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_method = quant_method
|
||||
self.quant_config = quant_config
|
||||
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
|
||||
self.added_scale_attrs = [
|
||||
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
|
||||
@@ -52,7 +60,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(ffn1_weights) == layer.num_local_experts
|
||||
assert len(ffn2_weights) == layer.num_local_experts
|
||||
assert layer.quant_method.quant_config.name() == "wint8"
|
||||
|
||||
algo = layer.quant_method.quant_config.name()
|
||||
|
||||
assert algo == "wint8"
|
||||
|
||||
assert ffn1_weights[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
@@ -63,9 +75,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
|
||||
|
||||
if self.quant_config.name() == "wint8":
|
||||
if algo == "wint8":
|
||||
max_bound = 127
|
||||
elif self.quant_config.name() == "wint4":
|
||||
elif algo == "wint4":
|
||||
max_bound = 7
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
|
||||
@@ -111,15 +123,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
|
||||
|
||||
topk_weights, topk_ids = paddle.topk(scores,
|
||||
k=top_k,
|
||||
axis=-1,
|
||||
sorted=False)
|
||||
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
intermediate_cache1 = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
@@ -139,14 +149,12 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
|
||||
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
|
||||
max_num_tokens_padded = sorted_token_ids.shape[0]
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
max_possible_num_post_padded = sorted_token_ids.shape[0]
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x,
|
||||
@@ -158,10 +166,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
moe_intermediate_size * 2,
|
||||
hidden_size,
|
||||
max_num_tokens_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=moe_intermediate_size * 2,
|
||||
K=hidden_size,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
@@ -193,8 +201,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
|
||||
intermediate_cache1)
|
||||
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
fused_moe_kernel_paddle[grid](
|
||||
intermediate_cache2,
|
||||
layer.moe_ffn2_weight,
|
||||
@@ -205,10 +214,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
hidden_size,
|
||||
moe_intermediate_size,
|
||||
max_num_tokens_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=hidden_size,
|
||||
K=moe_intermediate_size,
|
||||
stride_am=intermediate_cache2.strides[0],
|
||||
stride_ak=intermediate_cache2.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
@@ -324,7 +333,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
|
||||
|
||||
topk_weights, topk_ids = paddle.topk(scores,
|
||||
@@ -352,13 +360,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
|
||||
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
|
||||
max_num_tokens_padded = sorted_token_ids.shape[0]
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
max_possible_num_post_padded = sorted_token_ids.shape[0]
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
adamard_matrix = create_hadamard_matrix_map[hidden_size]
|
||||
x = paddle.matmul(x.cast("float32"), adamard_matrix)
|
||||
@@ -371,8 +379,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
permute_x = permute_x / quant_activation_scale
|
||||
permute_x = permute_x.astype("float8_e4m3fn")
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
permute_x,
|
||||
layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
|
||||
@@ -383,10 +389,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
moe_intermediate_size * 2,
|
||||
hidden_size,
|
||||
max_num_tokens_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=moe_intermediate_size * 2,
|
||||
K=hidden_size,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
@@ -426,8 +432,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
intermediate_cache2 = intermediate_cache2 / quant_activation_scale
|
||||
intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn")
|
||||
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
intermediate_cache2,
|
||||
@@ -439,10 +446,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
hidden_size,
|
||||
moe_intermediate_size,
|
||||
max_num_tokens_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=hidden_size,
|
||||
K=moe_intermediate_size,
|
||||
stride_am=intermediate_cache2.strides[0],
|
||||
stride_ak=intermediate_cache2.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
|
@@ -224,6 +224,7 @@ class TritonWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
)
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
|
||||
|
||||
fused_moe_out = moe_expert_reduce(
|
||||
ffn_out,
|
||||
topk_weights,
|
||||
|
@@ -30,10 +30,15 @@ class FusedMoE(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
reduce_results: bool = True,
|
||||
moe_intermediate_size: int = -1,
|
||||
num_experts: int = -1,
|
||||
expert_id_offset: int = 0,
|
||||
top_k: int = -1,
|
||||
topk_method: str = "",
|
||||
topk_group: int = -1,
|
||||
n_group: int = -1,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
layer_idx: int = -1,
|
||||
moe_tag: str = "",
|
||||
weight_key_map: dict = {},
|
||||
@@ -49,6 +54,7 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
self.fd_config = fd_config
|
||||
self.layer_idx = layer_idx
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_degree
|
||||
@@ -60,28 +66,33 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.moe_config = fd_config.moe_config
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.num_local_experts = self.num_experts // self.ep_size
|
||||
|
||||
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
|
||||
|
||||
self.top_k = top_k
|
||||
self.hidden_size = self.hidden_size
|
||||
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
|
||||
self.weight_key_map = weight_key_map
|
||||
|
||||
self.use_method = envs.FD_MOE_BACKEND.lower()
|
||||
self.gate_correction_bias = None
|
||||
self.moe_tag = moe_tag
|
||||
|
||||
if self.ep_size > 1:
|
||||
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
|
||||
|
||||
self.expert_id_offset = expert_id_offset
|
||||
|
||||
if fd_config.quant_config:
|
||||
self.quant_method = fd_config.quant_config.get_quant_method(self)
|
||||
# used for deepseek_v3
|
||||
self.topk_method = topk_method
|
||||
self.topk_group = topk_group
|
||||
self.n_group = n_group
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
|
||||
moe_quant_config = fd_config.quant_config
|
||||
self.moe_quant_type = None
|
||||
if moe_quant_config:
|
||||
self.quant_method = moe_quant_config.get_quant_method(self)
|
||||
self.moe_quant_type = moe_quant_config.name()
|
||||
else:
|
||||
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
|
||||
from .fused_moe_cutlass_backend import CutlassMoEMethod
|
||||
@@ -90,12 +101,78 @@ class FusedMoE(nn.Layer):
|
||||
if self.ep_size > 1:
|
||||
self.quant_method.init_ep(self)
|
||||
|
||||
if fd_config.load_config.dynamic_load_weight:
|
||||
# It's for RL to build model
|
||||
self.init_moe_weights()
|
||||
|
||||
logger.info(
|
||||
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \
|
||||
{top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \
|
||||
, ep_size={self.ep_size}, \
|
||||
tp_size={self.tp_size}.")
|
||||
|
||||
def init_moe_weights(self):
|
||||
"""
|
||||
Initialize the weight shapes and parameters for the MoE layer.
|
||||
Combines weight shape initialization and parameter creation into a single function.
|
||||
"""
|
||||
# Initialize weight shapes
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self.weight_dtype = self._dtype
|
||||
gate_weight_shape = [self.hidden_size, self.num_experts]
|
||||
gate_correction_bias_shape = [1, self.num_experts]
|
||||
|
||||
self.gate_weight = self.create_parameter(
|
||||
shape=gate_weight_shape,
|
||||
dtype="float32",
|
||||
)
|
||||
if self.moe_config.moe_use_aux_free:
|
||||
self.gate_correction_bias = self.create_parameter(
|
||||
shape=gate_correction_bias_shape,
|
||||
dtype="float32",
|
||||
)
|
||||
ffn1_output_dim = self.moe_intermediate_size * 2
|
||||
if self.moe_quant_type in ["fp8", "wint8"]:
|
||||
ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size]
|
||||
ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
|
||||
else:
|
||||
ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim]
|
||||
ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
|
||||
|
||||
# Create parameters
|
||||
if self.moe_quant_type == "fp8":
|
||||
#(TODO:gaoziyuan)
|
||||
pass
|
||||
elif self.moe_quant_type == "wint8":
|
||||
self.weight_dtype = "int8"
|
||||
self.init_weight_only_scale()
|
||||
|
||||
# FFN1 parameters
|
||||
self.moe_ffn1_weight = self.create_parameter(
|
||||
shape=ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
# FFN2 parameters
|
||||
self.moe_ffn2_weight = self.create_parameter(
|
||||
shape=ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
def init_weight_only_scale(self):
|
||||
"""
|
||||
Initialize the weight scale.
|
||||
"""
|
||||
self.moe_ffn1_weight_scale = self.create_parameter(
|
||||
shape=[self.num_local_experts, self.moe_intermediate_size * 2],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
self.moe_ffn2_weight_scale = self.create_parameter(
|
||||
shape=[self.num_local_experts, self.hidden_size],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
def load_experts_weight(self, state_dict: dict,
|
||||
ffn1_expert_weight_key: str,
|
||||
ffn2_expert_weight_key: str):
|
||||
|
@@ -16,9 +16,10 @@
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import paddle_use_triton_v2
|
||||
|
||||
|
||||
@triton.jit
|
||||
@paddle_use_triton_v2()
|
||||
def fused_moe_kernel_paddle(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
@@ -31,22 +32,22 @@ def fused_moe_kernel_paddle(
|
||||
num_tokens_post_padded_ptr,
|
||||
|
||||
# Matrix dimensions
|
||||
N,
|
||||
K,
|
||||
num_tokens_post_padded,
|
||||
max_possible_num_post_padded,
|
||||
num_valid_tokens,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
stride_am: tl.constexpr,
|
||||
stride_ak: tl.constexpr,
|
||||
stride_be: tl.constexpr,
|
||||
stride_bk: tl.constexpr,
|
||||
stride_bn: tl.constexpr,
|
||||
stride_cm: tl.constexpr,
|
||||
stride_cn: tl.constexpr,
|
||||
stride_asm: tl.constexpr,
|
||||
stride_ask: tl.constexpr,
|
||||
stride_bse: tl.constexpr,
|
||||
stride_bsk: tl.constexpr,
|
||||
stride_bsn: tl.constexpr,
|
||||
# Block size for block-wise fp8 quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
@@ -87,7 +88,7 @@ def fused_moe_kernel_paddle(
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M)
|
||||
num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
|
133
fastdeploy/model_executor/layers/mtp_linear.py
Normal file
133
fastdeploy/model_executor/layers/mtp_linear.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
|
||||
class ParallelEHProjection(nn.Layer):
|
||||
"""
|
||||
"Parallelized Embedding Hidden States Projection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
prefix="",
|
||||
with_bias=False,
|
||||
):
|
||||
"""
|
||||
Parallelized Embedding Hidden States Projection.
|
||||
|
||||
Args:
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
num_embeddings (int): vocabulary size.
|
||||
embedding_dim (int): size of hidden state.
|
||||
prefix (str): full name of the layer in the state dict
|
||||
"""
|
||||
super(ParallelEHProjection, self).__init__()
|
||||
self.linear_weight_key = prefix + ".weight"
|
||||
if with_bias:
|
||||
self.linear_bias_key = prefix + ".bias"
|
||||
else:
|
||||
self.linear_bias_key = None
|
||||
self.use_ep = fd_config.parallel_config.use_ep
|
||||
self.column_cut = True
|
||||
|
||||
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
|
||||
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
|
||||
|
||||
if self.use_ep:
|
||||
self.weight = self.create_parameter(
|
||||
shape=[embedding_dim, num_embeddings],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
is_bias=False,
|
||||
)
|
||||
else:
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.out_linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().
|
||||
get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True
|
||||
if self.linear_bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False, # False diff更小
|
||||
)
|
||||
else:
|
||||
self.out_linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().
|
||||
get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True
|
||||
if self.linear_bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False, # False diff更小
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
Args:
|
||||
state_dict (dict): A dictionary containing the checkpoint weights and biases.
|
||||
"""
|
||||
|
||||
if self.use_ep:
|
||||
self.weight.set_value(
|
||||
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
|
||||
paddle.get_default_dtype()))
|
||||
else:
|
||||
weight_tensor = get_tensor(
|
||||
state_dict.pop(self.linear_weight_key)).astype(
|
||||
paddle.get_default_dtype())
|
||||
if self.out_linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
self.out_linear.weight.set_value(weight_tensor)
|
||||
|
||||
if self.linear_bias_key is not None:
|
||||
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
|
||||
paddle.get_default_dtype())
|
||||
self.out_linear.bias.set_value(bias)
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
Args:
|
||||
input (Tensor): The input tensor to the layer.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor after processing through the layer.
|
||||
"""
|
||||
logits = input
|
||||
if self.use_ep:
|
||||
logits = paddle.matmul(logits, self.weight)
|
||||
else:
|
||||
logits = self.out_linear(logits)
|
||||
return logits
|
@@ -14,10 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
|
||||
@@ -28,16 +33,16 @@ class RMSNorm(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
prefix="",
|
||||
linear_bias=None,
|
||||
quant_scale=None,
|
||||
begin_norm_axis=1,
|
||||
):
|
||||
fd_config: FDConfig,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-5,
|
||||
prefix: str = "",
|
||||
linear_bias: paddle.Tensor = None,
|
||||
quant_scale: float = None,
|
||||
begin_norm_axis: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the normalization layer.
|
||||
Initializes the RMSNormalization layer.
|
||||
|
||||
Args:
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
@@ -45,33 +50,33 @@ class RMSNorm(nn.Layer):
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
hidden_size (int) : size of hidden state.
|
||||
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
|
||||
weight_key (str): Key name of weight in the pdparams state dict. Defaults to None, means no weight.
|
||||
bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
|
||||
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
|
||||
prefix(str,optional):The name of current layer. Defaults to "".
|
||||
linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
|
||||
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
|
||||
begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the specified norm_type is not supported.
|
||||
"""
|
||||
super().__init__()
|
||||
self.fd_config = fd_config
|
||||
self.prefix = prefix
|
||||
self.hidden_size = hidden_size
|
||||
self.prefix: str = prefix
|
||||
self.hidden_size: int = hidden_size
|
||||
if len(prefix) == 0:
|
||||
self.weight_key = None
|
||||
self.weight_key: Optional[str] = None
|
||||
else:
|
||||
self.weight_key = f"{prefix}.weight"
|
||||
self.with_weight = self.weight_key is not None
|
||||
self.eps = eps
|
||||
self.norm_func = fused_rms_norm
|
||||
self.linear_bias = linear_bias
|
||||
self.quant_scale = quant_scale
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype = self._dtype
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
self.weight_key: Optional[str] = f"{prefix}.weight"
|
||||
self.with_weight: bool = self.weight_key is not None
|
||||
self.eps: float = eps
|
||||
self.norm_func: Callable = fused_rms_norm
|
||||
self.linear_bias: Optional[paddle.Tensor] = linear_bias
|
||||
self.quant_scale: Optional[float] = quant_scale
|
||||
self._dtype: str = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype: str = self._dtype
|
||||
self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
self.begin_norm_axis: int = begin_norm_axis
|
||||
|
||||
self.init_weight()
|
||||
|
||||
@@ -88,7 +93,8 @@ class RMSNorm(nn.Layer):
|
||||
dtype=self._norm_weight_dtype,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
@@ -102,7 +108,10 @@ class RMSNorm(nn.Layer):
|
||||
self._norm_weight_dtype)
|
||||
self.ln_weight.set_value(weight_tensor)
|
||||
|
||||
def forward(self, x, residual_input=None):
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
@@ -140,18 +149,18 @@ class RMSNorm(nn.Layer):
|
||||
|
||||
class LayerNorm(nn.Layer):
|
||||
"""
|
||||
Normalization layer.
|
||||
Initializes the LayerNormalization layer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
fd_config: FDConfig,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-5,
|
||||
prefix="",
|
||||
linear_bias=None,
|
||||
quant_scale=None,
|
||||
with_bias=False,
|
||||
linear_bias: paddle.Tensor = None,
|
||||
quant_scale: float = None,
|
||||
with_bias: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the normalization layer.
|
||||
@@ -160,35 +169,37 @@ class LayerNorm(nn.Layer):
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
prefix (str): Unique name of the layer, used for naming internal attributes,
|
||||
you can give it any name you like.
|
||||
hidden_size (int) : size of hidden state.
|
||||
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
|
||||
prefix (str): Unique name of the layer, used for naming internal attributes,
|
||||
you can give it any name you like.
|
||||
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
|
||||
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
|
||||
with_bias (bool):Whether to include bias or not. Defaults to False.
|
||||
Raises:
|
||||
NotImplementedError: If the specified norm_type is not supported.
|
||||
"""
|
||||
super().__init__()
|
||||
self.fd_config = fd_config
|
||||
self.prefix = prefix
|
||||
self.hidden_size = hidden_size
|
||||
self.prefix: str = prefix
|
||||
self.hidden_size: int = hidden_size
|
||||
if len(prefix) == 0:
|
||||
self.weight_key = None
|
||||
self.weight_key: Optional[str] = None
|
||||
else:
|
||||
self.weight_key = f"{prefix}.weight"
|
||||
self.with_weight = self.weight_key is not None
|
||||
self.bias_key = f"{prefix}.bias"
|
||||
self.with_bias = with_bias
|
||||
self.eps = eps
|
||||
self.weight_key: Optional[str] = f"{prefix}.weight"
|
||||
self.with_weight: bool = self.weight_key is not None
|
||||
self.bias_key: str = f"{prefix}.bias"
|
||||
self.with_bias: bool = with_bias
|
||||
self.eps: float = eps
|
||||
self.quant_scale: float = quant_scale
|
||||
self.norm_func: Callable = fused_layer_norm
|
||||
self.linear_bias: Optional[paddle.Tensor] = linear_bias
|
||||
self._dtype: str = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype: str = "float32"
|
||||
|
||||
self.norm_func = fused_layer_norm
|
||||
self.linear_bias = linear_bias
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype = "float32"
|
||||
|
||||
self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
|
||||
self.init_weight()
|
||||
|
||||
@@ -212,7 +223,8 @@ class LayerNorm(nn.Layer):
|
||||
dtype=self._norm_weight_dtype,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
@@ -233,7 +245,10 @@ class LayerNorm(nn.Layer):
|
||||
self._norm_weight_dtype)
|
||||
self.ln_bias.set_value(bias_tensor)
|
||||
|
||||
def forward(self, x, residual_input=None):
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
@@ -259,7 +274,7 @@ class LayerNorm(nn.Layer):
|
||||
begin_norm_axis=1,
|
||||
bias=self.linear_bias,
|
||||
residual=residual_input,
|
||||
quant_scale=-1,
|
||||
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
|
||||
quant_round_type=self.quant_round_type,
|
||||
quant_max_bound=self.quant_max_bound,
|
||||
quant_min_bound=self.quant_min_bound,
|
||||
|
@@ -15,8 +15,9 @@
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..attention import Attention
|
||||
from ..moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
|
||||
from . import get_quantization_config
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
@@ -132,18 +132,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer):
|
||||
|
||||
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
|
||||
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
|
||||
|
||||
layer.linear_weight_shape.reverse()
|
||||
if self.quant_config.name() == "wint4":
|
||||
layer.linear_weight_shape[0] //= 2
|
||||
layer.weight_dtype = "int8"
|
||||
linear_weight_scale_shape = [layer.embed_dim]
|
||||
if hasattr(layer, "linear_weight_shape"):
|
||||
if isinstance(layer.linear_weight_shape, list):
|
||||
layer_weight_shape = layer.linear_weight_shape
|
||||
linear_weight_scale_shape = layer_weight_shape[:1]
|
||||
if self.quant_config.name() == "wint4":
|
||||
linear_weight_scale_shape[0] *= 2
|
||||
|
||||
layer.linear_weight_scale = layer.create_parameter(
|
||||
shape=linear_weight_scale_shape,
|
||||
dtype=layer._dtype,
|
||||
@@ -195,6 +191,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
weight_scale.astype(paddle.get_default_dtype()))
|
||||
|
||||
def process_loaded_weights(self, layer, weight) -> None:
|
||||
|
||||
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
|
||||
weight,
|
||||
algo=self.quant_config.algo,
|
||||
|
@@ -14,13 +14,18 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding
|
||||
|
||||
from .utils import CpuGuard
|
||||
|
||||
|
||||
@@ -99,20 +104,164 @@ class QwenRotaryEmbedding:
|
||||
return rot_emb
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
"""
|
||||
"""
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_find_correction_dim(num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
"""
|
||||
"""
|
||||
return (dim * math.log(max_position_embeddings /
|
||||
(num_rotations * 2 * math.pi))) / (2 *
|
||||
math.log(base))
|
||||
|
||||
|
||||
def yarn_find_correction_range(low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
"""
|
||||
"""
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min, max, dim):
|
||||
"""
|
||||
"""
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max -
|
||||
min)
|
||||
ramp_func = paddle.clip(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
class DeepseekScalingRotaryEmbedding(nn.Layer):
|
||||
"""RotaryEmbedding extended with YaRN method.
|
||||
|
||||
Credits to Peng et al. github.com/jquesnelle/yarn
|
||||
|
||||
Args:
|
||||
rotary_dim(int): Dimension of rotary embeddings (head dimension)
|
||||
max_position_embeddings(int): Original training context length
|
||||
base(float): Base value used to compute the inverse frequencies.
|
||||
scaling_factor(float): Context extension scaling ratio (target_len / original_len)
|
||||
extrapolation_factor(float): Weight for extrapolated frequencies (default=1)
|
||||
attn_factor(float): Attention magnitude scaling factor (default=1)
|
||||
beta_fast(int): High-frequency correction cutoff (default=32)
|
||||
beta_slow(int): Low-frequency correction cutoff (default=1)
|
||||
mscale(float): Primary magnitude scaling factor (default=1)
|
||||
mscale_all_dim(float): Alternate magnitude scaling factor (default=0)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
scaling_factor: float,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._dtype = paddle.get_default_dtype()
|
||||
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
|
||||
self.cos_sin_cache: paddle.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistable=True)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
|
||||
pos_freqs = self.base**(
|
||||
paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) /
|
||||
self.rotary_dim)
|
||||
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
|
||||
|
||||
low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
|
||||
self.rotary_dim, self.base,
|
||||
self.max_position_embeddings)
|
||||
# Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = (1 - yarn_linear_ramp_mask(
|
||||
low, high, self.rotary_dim // 2)) * self.extrapolation_factor
|
||||
inv_freq = inv_freq_interpolation * (
|
||||
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> paddle.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||
t = paddle.arange(self.max_position_embeddings * self.scaling_factor,
|
||||
dtype=paddle.float32)
|
||||
freqs = paddle.einsum("i,j->ij", t, inv_freq)
|
||||
cos = freqs.cos() * self.mscale
|
||||
sin = freqs.sin() * self.mscale
|
||||
cache = paddle.concat((cos, sin), axis=-1)
|
||||
return cache.cast(self._dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: paddle.Tensor,
|
||||
query: paddle.Tensor,
|
||||
key: paddle.Tensor,
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
"""
|
||||
# In-place operations that update the query and key tensors.
|
||||
fused_rotary_position_encoding(query, key, position_ids,
|
||||
self.cos_sin_cache, self.rotary_dim,
|
||||
False)
|
||||
|
||||
return query, key
|
||||
|
||||
|
||||
def get_rope_impl(
|
||||
rotary_dim: int,
|
||||
base: 10000.0,
|
||||
position_ids,
|
||||
position_ids: paddle.Tensor,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
partial_rotary_factor=1,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
The real implementation of get_rope
|
||||
"""
|
||||
|
||||
architecture = model_config.architectures[0]
|
||||
if model_config is not None and model_config is None or architecture.startswith(
|
||||
"Qwen"):
|
||||
if model_config is None or architecture.startswith("Qwen"):
|
||||
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base,
|
||||
partial_rotary_factor)
|
||||
rotary_emb = rotary_emb_layer(position_ids)
|
||||
@@ -126,10 +275,10 @@ def get_rope_impl(
|
||||
def get_rope_xpu(
|
||||
rotary_dim: int,
|
||||
base: 10000.0,
|
||||
position_ids,
|
||||
model_config: ModelConfig,
|
||||
position_ids: paddle.Tensor,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
partial_rotary_factor=1,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
In XPU, cos and sin compute must be done on cpu
|
||||
"""
|
||||
@@ -143,12 +292,27 @@ def get_rope_xpu(
|
||||
def get_rope(
|
||||
rotary_dim: int,
|
||||
base: 10000.0,
|
||||
position_ids,
|
||||
model_config: ModelConfig,
|
||||
partial_rotary_factor=1,
|
||||
):
|
||||
position_ids: paddle.Tensor,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
partial_rotary_factor: int = 1,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
The warpper of get_rope
|
||||
Pre-calculate rotary position embedding for position_ids.
|
||||
|
||||
Args:
|
||||
rotary_dim (int):
|
||||
Dimension of rotary embeddings (head dimension)
|
||||
base (float, optional):
|
||||
Base value used to compute the inverse frequencies.
|
||||
Default: 10000.0.
|
||||
position_ids (paddle.Tensor):
|
||||
Tensor containing position indices of input tokens.
|
||||
model_config (Optional[ModelConfig]):
|
||||
Model configuration object containing architecture information.
|
||||
If provided, determines RoPE implementation based on model architecture.
|
||||
partial_rotary_factor (int, optional):
|
||||
Factor controlling partial rotary application.
|
||||
Default: 1 (apply to all dimensions).
|
||||
"""
|
||||
if current_platform.is_xpu():
|
||||
return get_rope_xpu(rotary_dim, base, position_ids, model_config,
|
||||
@@ -255,7 +419,24 @@ def get_rope_3d(
|
||||
paritial_rotary_factor: 1,
|
||||
max_position: 131072,
|
||||
freq_allocation: 2,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Pre-calculate rotary position embedding for position_ids.
|
||||
|
||||
Args:
|
||||
rotary_dim (int):
|
||||
Dimension of rotary embeddings (head dimension)
|
||||
base (float, optional):
|
||||
Base value used to compute the inverse frequencies.
|
||||
Default: 10000.0.
|
||||
position_ids (paddle.Tensor):
|
||||
Tensor containing position indices of input tokens.
|
||||
partial_rotary_factor (int, optional):
|
||||
Factor controlling partial rotary application.
|
||||
Default: 1 (apply to all dimensions).
|
||||
max_position: Maximum position index to precompute.
|
||||
freq_allocation: Number of rotary dimensions allocated to temporal axis
|
||||
"""
|
||||
rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base,
|
||||
paritial_rotary_factor,
|
||||
max_position,
|
||||
|
@@ -377,4 +377,4 @@ def create_and_set_parameter(layer: nn.Layer, name: str,
|
||||
dtype=tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
))
|
||||
getattr(layer, name).set_value(tensor)
|
||||
getattr(layer, name).set_value(tensor)
|
289
fastdeploy/model_executor/load_weight_utils.py
Normal file
289
fastdeploy/model_executor/load_weight_utils.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
from paddleformers.transformers.model_utils import load_tp_checkpoint
|
||||
from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
|
||||
from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.models.tp_utils import \
|
||||
check_tensor_parallel_prerequisites
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def load_ep_checkpoint(model_path: str,
|
||||
config: ModelConfig,
|
||||
return_numpy: bool = False):
|
||||
"""
|
||||
load ep checkpoint
|
||||
"""
|
||||
with open(os.path.join(model_path, "model.safetensors.index.json"),
|
||||
"r") as f:
|
||||
weight_list = json.load(f)["weight_map"]
|
||||
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
|
||||
num_local_ffn_keys = []
|
||||
|
||||
for i in range(config.moe_layer_start_index, config.num_layers):
|
||||
for j in range(
|
||||
config.num_experts_start_offset,
|
||||
config.num_experts_start_offset + config.num_experts_per_rank,
|
||||
):
|
||||
ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
|
||||
ffn2_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
|
||||
|
||||
ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
|
||||
ffn2_quant_key = (
|
||||
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight")
|
||||
|
||||
ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
|
||||
ffn2_scale_key = (
|
||||
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale")
|
||||
num_local_ffn_keys.append(ffn1_key)
|
||||
num_local_ffn_keys.append(ffn2_key)
|
||||
num_local_ffn_keys.append(ffn1_quant_key)
|
||||
num_local_ffn_keys.append(ffn2_quant_key)
|
||||
num_local_ffn_keys.append(ffn1_scale_key)
|
||||
num_local_ffn_keys.append(ffn2_scale_key)
|
||||
|
||||
for k in num_local_ffn_keys:
|
||||
if k in weight_list:
|
||||
filtered_map[k] = weight_list[k]
|
||||
|
||||
state_dict = {}
|
||||
# Get all safetensor file paths that need to be opened
|
||||
safetensor_paths = set(filtered_map.values())
|
||||
|
||||
# Open each safetensor file sequentially with progress bar
|
||||
for safetensor_path in tqdm(safetensor_paths,
|
||||
desc="Loading safetensor files",
|
||||
unit="file"):
|
||||
with safe_open(os.path.join(model_path, safetensor_path),
|
||||
framework="np",
|
||||
device="cpu") as f:
|
||||
# Check if this file contains keys from filtered_map
|
||||
for k in filtered_map:
|
||||
if filtered_map[k] == safetensor_path and k in f.keys():
|
||||
weight = f.get_tensor(k)
|
||||
if not return_numpy:
|
||||
weight = paddle.Tensor(weight, zero_copy=True)
|
||||
weight = weight._copy_to(
|
||||
paddle.framework._current_expected_place(), False)
|
||||
state_dict[k] = weight
|
||||
return state_dict
|
||||
|
||||
|
||||
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
|
||||
"""
|
||||
safetensors_weights_iterator
|
||||
"""
|
||||
for st_file in tqdm(
|
||||
safe_tensor_list,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
):
|
||||
with safe_open(st_file, framework="np") as f:
|
||||
for name in f.keys():
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(safetensor_list: list[str], ):
|
||||
"""
|
||||
Return an iterator over tensors on GPU from a given safetensor_list.
|
||||
"""
|
||||
world_size = dist.get_world_size()
|
||||
if world_size > 1:
|
||||
pg = dist.get_group()
|
||||
device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
|
||||
else:
|
||||
pg = SingleGroup()
|
||||
device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda(
|
||||
) else "cpu"
|
||||
|
||||
safetensor_files_sub_lists = [
|
||||
safetensor_list[i:i + world_size]
|
||||
for i in range(0, len(safetensor_list), world_size)
|
||||
]
|
||||
|
||||
for st_file in tqdm(
|
||||
safetensor_files_sub_lists,
|
||||
desc="Loading fastsafetensors checkpoint shards",
|
||||
):
|
||||
loader = SafeTensorsFileLoader(pg,
|
||||
device,
|
||||
nogds=True,
|
||||
debug_log=False,
|
||||
framework="paddle")
|
||||
rank_file_map = {i: [f] for i, f in enumerate(st_file)}
|
||||
loader.add_filenames(rank_file_map)
|
||||
try:
|
||||
fb = loader.copy_files_to_device()
|
||||
try:
|
||||
keys = list(fb.key_to_rank_lidx.keys())
|
||||
for k in keys:
|
||||
t = fb.get_tensor(k)
|
||||
yield k, t
|
||||
finally:
|
||||
fb.close()
|
||||
finally:
|
||||
loader.close()
|
||||
|
||||
|
||||
def load_pre_sharded_checkpoint(model_path: str,
|
||||
local_rank: int,
|
||||
use_fastsafetensor: bool = False):
|
||||
"""
|
||||
load_pre_sharded_checkpoint
|
||||
"""
|
||||
state_dict = {}
|
||||
_, safetensor_files = get_all_safetensors(
|
||||
os.path.join(model_path, f"rank{local_rank}"))
|
||||
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
||||
for name, weight in weights_iterator:
|
||||
state_dict[name] = weight
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_all_safetensors(model_path: str):
|
||||
"""
|
||||
get_all_safetensors
|
||||
"""
|
||||
safe_model_path = os.path.join(model_path, "model.safetensors")
|
||||
if os.path.exists(safe_model_path):
|
||||
safetensor_list = [safe_model_path]
|
||||
with safe_open(safe_model_path, framework="np", device="cpu") as f:
|
||||
key_name_list = f.keys()
|
||||
return key_name_list, safetensor_list
|
||||
else:
|
||||
with open(os.path.join(model_path, "model.safetensors.index.json"),
|
||||
"r") as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
weight_files_in_index = set()
|
||||
for weight_name in weight_map:
|
||||
weight_files_in_index.add(
|
||||
os.path.join(model_path, weight_map[weight_name]))
|
||||
key_name_list = list(set(weight_map.keys()))
|
||||
safetensor_list = list(weight_files_in_index)
|
||||
safetensor_list.sort()
|
||||
return key_name_list, safetensor_list
|
||||
|
||||
|
||||
def load_tp_checkpoint_v1(
|
||||
model_path: str,
|
||||
cls: PretrainedModel,
|
||||
fd_config: FDConfig,
|
||||
use_fastsafetensor: bool = True,
|
||||
):
|
||||
"""
|
||||
load_tp_checkpoint_v1
|
||||
"""
|
||||
|
||||
safetensor_keys, safetensor_files = get_all_safetensors(model_path)
|
||||
|
||||
if use_fastsafetensor:
|
||||
weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
||||
|
||||
tensor_parallel_filtered_map = {}
|
||||
check_tensor_parallel_prerequisites(
|
||||
fd_config,
|
||||
cls,
|
||||
tensor_parallel_filtered_map,
|
||||
safetensor_keys,
|
||||
)
|
||||
need_tp = True if tensor_parallel_filtered_map else False
|
||||
state_dict = {}
|
||||
for key, weight in weights_iterator:
|
||||
paddle.device.synchronize()
|
||||
if need_tp and key in tensor_parallel_filtered_map:
|
||||
action = tensor_parallel_filtered_map.pop(key)
|
||||
tensor = action(weight).clone()
|
||||
else:
|
||||
tensor = weight.clone()
|
||||
state_dict[key] = tensor
|
||||
weight.value().get_tensor()._clear()
|
||||
return state_dict
|
||||
|
||||
|
||||
def deal_state_dict(state_dict):
|
||||
"""deal_state_dict"""
|
||||
device = paddle.CUDAPinnedPlace()
|
||||
for name, src in state_dict.items():
|
||||
if src._is_initialized() and not isinstance(src.place,
|
||||
paddle.CUDAPinnedPlace):
|
||||
dst = src._copy_to(device, True)
|
||||
dst_tensor = dst.value().get_tensor()
|
||||
src_tensor = src.value().get_tensor()
|
||||
src_tensor._clear()
|
||||
src_tensor._share_data_with(dst_tensor)
|
||||
|
||||
|
||||
def load_composite_checkpoint(
|
||||
model_path: str,
|
||||
cls: PretrainedModel,
|
||||
fd_config: FDConfig,
|
||||
return_numpy=True,
|
||||
):
|
||||
"""
|
||||
# This method supports loading model weights under three parallelism strategies:
|
||||
# 1. Expert Parallel (EP)
|
||||
# 2. Tensor Parallel (TP)
|
||||
# 3. Pre-sharded (pre-split)
|
||||
"""
|
||||
if fd_config.parallel_config.use_ep:
|
||||
state_dict = load_ep_checkpoint(model_path,
|
||||
fd_config.model_config,
|
||||
return_numpy=True)
|
||||
else:
|
||||
rank_dirs = [
|
||||
f for f in os.listdir(model_path) if f.startswith("rank")
|
||||
and os.path.isdir(os.path.join(model_path, f))
|
||||
]
|
||||
if len(rank_dirs) > 1:
|
||||
if fd_config.parallel_config.tensor_parallel_degree != len(
|
||||
rank_dirs):
|
||||
raise ValueError(
|
||||
f"Your model only supports loading with tp{len(rank_dirs)}"
|
||||
)
|
||||
state_dict = load_pre_sharded_checkpoint(
|
||||
model_path,
|
||||
fd_config.parallel_config.tensor_parallel_rank,
|
||||
use_fastsafetensor=False,
|
||||
)
|
||||
else:
|
||||
if fd_config.load_config.use_fastsafetensor and (
|
||||
current_platform.available()
|
||||
and current_platform.is_cuda()):
|
||||
state_dict = load_tp_checkpoint_v1(model_path,
|
||||
cls,
|
||||
fd_config,
|
||||
use_fastsafetensor=True)
|
||||
deal_state_dict(state_dict)
|
||||
else:
|
||||
state_dict = load_tp_checkpoint(model_path,
|
||||
cls,
|
||||
fd_config.model_config,
|
||||
return_numpy=return_numpy)
|
||||
if not state_dict:
|
||||
raise ValueError("weight not found in state_dict !")
|
||||
return state_dict
|
@@ -20,6 +20,10 @@ import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
|
||||
from fastdeploy.model_executor.load_weight_utils import \
|
||||
load_composite_checkpoint
|
||||
from fastdeploy.model_executor.models.deepseek_v3 import \
|
||||
DeepSeekV3PretrainedModel
|
||||
from fastdeploy.model_executor.models.ernie4_5_moe import \
|
||||
Ernie4_5_PretrainedModel
|
||||
from fastdeploy.model_executor.models.ernie4_5_mtp import \
|
||||
@@ -28,7 +32,7 @@ from fastdeploy.model_executor.models.model_base import ModelRegistry
|
||||
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
|
||||
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
|
||||
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
|
||||
from fastdeploy.model_executor.models.utils import load_checkpoint
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel,
|
||||
@@ -36,7 +40,8 @@ MODEL_CLASSES = {
|
||||
"Qwen2ForCausalLM": Qwen2PretrainedModel,
|
||||
"Qwen3ForCausalLM": Qwen3PretrainedModel,
|
||||
"Qwen3MoeForCausalLM": Qwen3MoePretrainedModel,
|
||||
"Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel
|
||||
"Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel,
|
||||
"DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel,
|
||||
}
|
||||
|
||||
|
||||
@@ -73,23 +78,43 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
pass
|
||||
|
||||
def clean_memory_fragments(self, state_dict: dict) -> None:
|
||||
"""clean_memory_fragments"""
|
||||
if current_platform.is_cuda():
|
||||
if state_dict:
|
||||
for k, v in state_dict.items():
|
||||
if isinstance(v, paddle.Tensor):
|
||||
v.value().get_tensor()._clear()
|
||||
paddle.device.cuda.empty_cache()
|
||||
paddle.device.synchronize()
|
||||
|
||||
def load_model(self, fd_config: FDConfig) -> nn.Layer:
|
||||
context = paddle.LazyGuard()
|
||||
architectures = fd_config.model_config.architectures[0]
|
||||
|
||||
# TODO(gongshaotian): Now, only support safetensor
|
||||
if fd_config.load_config.dynamic_load_weight:
|
||||
# register rl model
|
||||
import fastdeploy.rl
|
||||
architectures = architectures + "RL"
|
||||
|
||||
model_class = MODEL_CLASSES[architectures]
|
||||
state_dict = load_checkpoint(
|
||||
fd_config.parallel_config.model_name_or_path,
|
||||
model_class,
|
||||
fd_config.model_config,
|
||||
return_numpy=True)
|
||||
with context:
|
||||
model_cls = ModelRegistry.get_class(architectures)
|
||||
model = model_cls(fd_config)
|
||||
|
||||
model.eval()
|
||||
model.set_state_dict(state_dict)
|
||||
|
||||
# RL model not need set_state_dict
|
||||
if fd_config.load_config.dynamic_load_weight:
|
||||
return model
|
||||
|
||||
# TODO(gongshaotian): Now, only support safetensor
|
||||
model_class = MODEL_CLASSES[architectures]
|
||||
state_dict = load_composite_checkpoint(
|
||||
fd_config.parallel_config.model_name_or_path,
|
||||
model_class,
|
||||
fd_config,
|
||||
return_numpy=True,
|
||||
)
|
||||
model.set_state_dict(state_dict)
|
||||
self.clean_memory_fragments(state_dict)
|
||||
return model
|
||||
|
@@ -20,15 +20,6 @@ from pathlib import Path
|
||||
|
||||
from .model_base import ModelForCasualLM, ModelRegistry
|
||||
|
||||
inference_runner_supported_models = [
|
||||
"Ernie4_5_MoeForCausalLM",
|
||||
"Ernie4_5_MTPForCausalLM",
|
||||
"Qwen2ForCausalLM",
|
||||
"Qwen3MoeForCausalLM",
|
||||
"Ernie4_5_ForCausalLM",
|
||||
"Qwen3ForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
def _find_py_files(root_dir):
|
||||
root_path = Path(root_dir)
|
||||
@@ -44,14 +35,14 @@ def _find_py_files(root_dir):
|
||||
return py_files
|
||||
|
||||
|
||||
def auto_models_registry():
|
||||
def auto_models_registry(dir_path,
|
||||
register_path="fastdeploy.model_executor.models"):
|
||||
"""
|
||||
auto registry all models in this folder
|
||||
"""
|
||||
for module_file in _find_py_files(os.path.dirname(__file__)):
|
||||
for module_file in _find_py_files(dir_path):
|
||||
try:
|
||||
module = importlib.import_module(
|
||||
f'fastdeploy.model_executor.models.{module_file}')
|
||||
module = importlib.import_module(f'{register_path}.{module_file}')
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
if inspect.isclass(attr) and issubclass(
|
||||
@@ -62,4 +53,4 @@ def auto_models_registry():
|
||||
raise ImportError(f"{module_file=} import error")
|
||||
|
||||
|
||||
auto_models_registry()
|
||||
auto_models_registry(os.path.dirname(__file__))
|
||||
|
762
fastdeploy/model_executor/models/deepseek_v3.py
Normal file
762
fastdeploy/model_executor/models/deepseek_v3.py
Normal file
@@ -0,0 +1,762 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear,
|
||||
ReplicatedLinear, RowParallelLinear)
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import \
|
||||
DeepseekScalingRotaryEmbedding
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
get_position_ids_and_mask_encoder_batch
|
||||
|
||||
|
||||
class DeepSeekV3MLP(nn.Layer):
|
||||
"""
|
||||
DeepSeekV3MLP, for Dense FFN and Shared Experts Layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
intermediate_size: int,
|
||||
prefix: str = "",
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.up_gate_proj",
|
||||
input_size=fd_config.model_config.hidden_size,
|
||||
output_size=intermediate_size * 2,
|
||||
with_bias=False,
|
||||
activation=fd_config.model_config.hidden_act,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
input_size=intermediate_size,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
with_bias=False,
|
||||
reduce_results=reduce_results,
|
||||
)
|
||||
|
||||
self.act_fn = SiluAndMul(
|
||||
fd_config=fd_config,
|
||||
bias=None,
|
||||
act_method=fd_config.model_config.hidden_act,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
"""
|
||||
self.gate_up_proj.load_state_dict(state_dict)
|
||||
self.down_proj.load_state_dict(state_dict)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
"""
|
||||
gate_up_out = self.gate_up_proj(x)
|
||||
act_out = self.act_fn(gate_up_out)
|
||||
down_out = self.down_proj(act_out)
|
||||
return down_out
|
||||
|
||||
|
||||
class DeepSeekV3MoE(nn.Layer):
|
||||
"""
|
||||
DeepSeekV3MoE, for MoE Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, layer_id: int,
|
||||
prefix: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
|
||||
|
||||
weight_key_map = {
|
||||
"gate_weight_key": f"{prefix}.gate.weight",
|
||||
"gate_correction_bias_key":
|
||||
f"{prefix}.gate.e_score_correction_bias",
|
||||
"ffn1_expert_weight_key":
|
||||
f"{prefix}.experts.{{}}.up_gate_proj.weight",
|
||||
"ffn2_expert_weight_key":
|
||||
f"{prefix}.experts.{{}}.down_proj.weight",
|
||||
}
|
||||
|
||||
self.fused_moe = FusedMoE(
|
||||
fd_config=fd_config,
|
||||
reduce_results=False,
|
||||
moe_intermediate_size=fd_config.model_config.deepseekv3.
|
||||
moe_intermediate_size,
|
||||
num_experts=fd_config.model_config.deepseekv3.n_routed_experts,
|
||||
top_k=fd_config.model_config.deepseekv3.num_experts_per_tok,
|
||||
topk_method=fd_config.model_config.deepseekv3.topk_method,
|
||||
topk_group=fd_config.model_config.deepseekv3.topk_group,
|
||||
n_group=fd_config.model_config.deepseekv3.n_group,
|
||||
routed_scaling_factor=fd_config.model_config.deepseekv3.
|
||||
routed_scaling_factor,
|
||||
layer_idx=layer_id,
|
||||
weight_key_map=weight_key_map,
|
||||
)
|
||||
|
||||
self.num_shared_experts = fd_config.model_config.deepseekv3.n_shared_experts
|
||||
shared_experts_intermediate_size = (
|
||||
self.num_shared_experts *
|
||||
fd_config.model_config.deepseekv3.moe_intermediate_size)
|
||||
|
||||
self.shared_experts = DeepSeekV3MLP(
|
||||
fd_config=fd_config,
|
||||
intermediate_size=shared_experts_intermediate_size,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
reduce_results=False,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
"""
|
||||
self.fused_moe.load_state_dict(state_dict)
|
||||
self.shared_experts.load_state_dict(state_dict)
|
||||
|
||||
def forward(self, hidden_states: paddle.Tensor):
|
||||
"""
|
||||
"""
|
||||
shared_experts_out = self.shared_experts(hidden_states)
|
||||
moe_out = self.fused_moe(hidden_states)
|
||||
moe_out = moe_out + shared_experts_out
|
||||
# We do to TP all reduce after the sum of experts.
|
||||
if self.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(moe_out)
|
||||
return moe_out
|
||||
|
||||
|
||||
class DeepseekV3MLAAttention(nn.Layer):
|
||||
"""
|
||||
DeepseekV3MLAAttention
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
fd_config: FDConfig,
|
||||
layer_id: int,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.num_attention_heads = fd_config.model_config.num_attention_heads
|
||||
self.num_attention_heads_tp = self.num_attention_heads // self.tp_size
|
||||
|
||||
# MLA
|
||||
self.qk_nope_head_dim = fd_config.model_config.deepseekv3.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = fd_config.model_config.deepseekv3.qk_rope_head_dim
|
||||
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
|
||||
self.v_head_dim = fd_config.model_config.deepseekv3.v_head_dim
|
||||
self.q_lora_rank = fd_config.model_config.deepseekv3.q_lora_rank
|
||||
self.kv_lora_rank = fd_config.model_config.deepseekv3.kv_lora_rank
|
||||
|
||||
self.attn_softmax_scale = self.qk_head_dim**-0.5
|
||||
self.rope_theta = fd_config.model_config.rope_theta
|
||||
self.rms_norm_eps = fd_config.model_config.rms_norm_eps
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(fd_config=fd_config,
|
||||
prefix=f"{prefix}.q_a_proj",
|
||||
input_size=self.hidden_size,
|
||||
output_size=self.q_lora_rank,
|
||||
with_bias=False)
|
||||
|
||||
self.q_a_layernorm = RMSNorm(fd_config,
|
||||
hidden_size=self.q_lora_rank,
|
||||
eps=self.rms_norm_eps,
|
||||
prefix=f"{prefix}.q_a_layernorm")
|
||||
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.q_b_proj",
|
||||
input_size=self.q_lora_rank,
|
||||
output_size=self.num_attention_heads * self.qk_head_dim,
|
||||
with_bias=False,
|
||||
)
|
||||
else:
|
||||
assert (self.q_lora_rank is not None
|
||||
), "self.q_lora_rank is None, Please Check your config."
|
||||
|
||||
# 不切TP,跑 W4A16 Gemm
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa",
|
||||
input_size=self.hidden_size,
|
||||
output_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
with_bias=False)
|
||||
|
||||
self.kv_a_layernorm = RMSNorm(fd_config,
|
||||
hidden_size=self.kv_lora_rank,
|
||||
eps=self.rms_norm_eps,
|
||||
prefix=f"{prefix}.kv_a_layernorm")
|
||||
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.kv_b_proj",
|
||||
input_size=self.kv_lora_rank,
|
||||
output_size=self.num_attention_heads *
|
||||
(self.qk_nope_head_dim + self.v_head_dim),
|
||||
with_bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(fd_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=self.num_attention_heads *
|
||||
self.v_head_dim,
|
||||
output_size=self.hidden_size,
|
||||
with_bias=False)
|
||||
|
||||
self.kv_b_proj_bmm = KVBatchLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.kv_b_proj",
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
v_head_dim=self.v_head_dim)
|
||||
|
||||
self.rope_scaling = fd_config.model_config.deepseekv3.rope_scaling
|
||||
if self.rope_scaling:
|
||||
mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
|
||||
scaling_factor = self.rope_scaling["factor"]
|
||||
mscale = self.yarn_get_mscale(scaling_factor,
|
||||
float(mscale_all_dim))
|
||||
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
|
||||
|
||||
rope_scaling_kwargs = {
|
||||
key: self.rope_scaling[key]
|
||||
for key in [
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
] if key in self.rope_scaling
|
||||
}
|
||||
self.rope_scaling_factor = self.rope_scaling["factor"]
|
||||
self.rope_scaling_original_max_position_embeddings = self.rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
self.rotary_emb = DeepseekScalingRotaryEmbedding(
|
||||
self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.
|
||||
rope_scaling_original_max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
scaling_factor=self.rope_scaling_factor,
|
||||
**rope_scaling_kwargs,
|
||||
)
|
||||
|
||||
self.mla_attn = Attention(
|
||||
fd_config=fd_config,
|
||||
layer_id=layer_id,
|
||||
prefix=prefix,
|
||||
use_neox_rotary_style=False,
|
||||
)
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
@staticmethod
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
"""
|
||||
"""
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
forward_meta: ForwardMeta,
|
||||
hidden_states: paddle.Tensor,
|
||||
position_ids: paddle.Tensor,
|
||||
mask_encoder_batch: paddle.Tensor,
|
||||
):
|
||||
"""
|
||||
"""
|
||||
layernorm_out = hidden_states
|
||||
fmha_out = paddle.zeros(shape=[
|
||||
layernorm_out.shape[0],
|
||||
self.num_attention_heads_tp * self.v_head_dim
|
||||
],
|
||||
dtype=layernorm_out.dtype)
|
||||
|
||||
decode_stage = forward_meta.is_decode_batch
|
||||
prefill_stage = not (forward_meta.is_decode_batch)
|
||||
|
||||
if prefill_stage:
|
||||
query = self.q_a_proj(layernorm_out)
|
||||
query = self.q_a_layernorm(query)
|
||||
query = self.q_b_proj(query)
|
||||
|
||||
query = query.reshape(
|
||||
[-1, self.num_attention_heads_tp, self.qk_head_dim])
|
||||
query_nope, query_pe = query.split(
|
||||
[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
|
||||
compressed_kv, key_pe = compressed_kv.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
|
||||
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
|
||||
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
|
||||
|
||||
key_value = self.kv_b_proj(compressed_kv)
|
||||
key_value = key_value.reshape([
|
||||
-1, self.num_attention_heads_tp,
|
||||
self.qk_nope_head_dim + self.v_head_dim
|
||||
])
|
||||
key_nope, value = key_value.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], axis=-1)
|
||||
|
||||
query[..., self.qk_nope_head_dim:] = query_pe
|
||||
key = paddle.empty_like(query)
|
||||
key[..., :self.qk_nope_head_dim] = key_nope
|
||||
key[..., self.qk_nope_head_dim:] = key_pe
|
||||
value = paddle.nn.functional.pad(
|
||||
value, [0, self.qk_head_dim - self.v_head_dim], value=0)
|
||||
|
||||
fmha_out_prefill = self.mla_attn(q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
qkv=None,
|
||||
compressed_kv=compressed_kv,
|
||||
k_pe=key_pe,
|
||||
forward_meta=forward_meta)
|
||||
|
||||
fmha_out_prefill = fmha_out_prefill.reshape(
|
||||
[-1, self.num_attention_heads_tp, self.qk_head_dim])
|
||||
fmha_out_prefill = fmha_out_prefill[:, :, :self.v_head_dim]
|
||||
fmha_out_prefill = fmha_out_prefill.reshape(
|
||||
[-1, self.num_attention_heads_tp * self.v_head_dim])
|
||||
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(
|
||||
fmha_out_prefill.dtype)
|
||||
|
||||
fmha_out = fmha_out + fmha_out_prefill
|
||||
|
||||
if decode_stage:
|
||||
query = self.q_a_proj(layernorm_out)
|
||||
query = self.q_a_layernorm(query)
|
||||
ln_out_or_q_c = query
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
|
||||
compressed_kv, key_pe = compressed_kv.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
|
||||
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
|
||||
query = self.q_b_proj(ln_out_or_q_c)
|
||||
query = query.reshape(
|
||||
[-1, self.num_attention_heads_tp, self.qk_head_dim])
|
||||
|
||||
query_nope, query_pe = query.split(
|
||||
[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
|
||||
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
|
||||
|
||||
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]),
|
||||
proj_type='k').transpose([1, 0, 2])
|
||||
|
||||
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
|
||||
q_input = q_input.reshape([
|
||||
-1,
|
||||
self.num_attention_heads_tp *
|
||||
(self.kv_lora_rank + self.qk_rope_head_dim),
|
||||
])
|
||||
fmha_out_decode = self.mla_attn(q=q_input,
|
||||
k=None,
|
||||
v=None,
|
||||
qkv=None,
|
||||
compressed_kv=compressed_kv,
|
||||
k_pe=key_pe,
|
||||
forward_meta=forward_meta)
|
||||
|
||||
fmha_out_decode = fmha_out_decode.reshape(
|
||||
[-1, self.num_attention_heads_tp,
|
||||
self.kv_lora_rank]).transpose([1, 0, 2])
|
||||
|
||||
fmha_out_decode = (self.kv_b_proj_bmm(
|
||||
fmha_out_decode, proj_type='v').transpose([1, 0, 2]).reshape(
|
||||
[-1, self.num_attention_heads_tp * self.v_head_dim]))
|
||||
fmha_out = fmha_out + fmha_out_decode
|
||||
|
||||
output = self.o_proj(fmha_out)
|
||||
return output
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
"""
|
||||
self.q_a_proj.load_state_dict(state_dict)
|
||||
self.q_a_layernorm.load_state_dict(state_dict)
|
||||
self.kv_a_proj_with_mqa.load_state_dict(state_dict)
|
||||
self.kv_a_layernorm.load_state_dict(state_dict)
|
||||
self.q_b_proj.load_state_dict(state_dict)
|
||||
self.kv_b_proj_bmm.load_state_dict(state_dict)
|
||||
self.kv_b_proj.load_state_dict(state_dict)
|
||||
# NOTE(Ryan):Make sure kv_b_proj_bmm loaded before kv_b_proj,
|
||||
# The same weight key will be poped after kv_b_proj.
|
||||
self.o_proj.load_state_dict(state_dict)
|
||||
|
||||
|
||||
class DeepSeekV3DecoderLayer(nn.Layer):
|
||||
"""
|
||||
DeepSeekV3DecoderLayer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
layer_id = int(prefix.split(sep='.')[-1])
|
||||
|
||||
self.self_attn = DeepseekV3MLAAttention(
|
||||
fd_config=fd_config,
|
||||
layer_id=layer_id,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
if (fd_config.model_config.deepseekv3.n_routed_experts is not None
|
||||
and layer_id
|
||||
>= fd_config.model_config.deepseekv3.first_k_dense_replace):
|
||||
self.mlp = DeepSeekV3MoE(
|
||||
fd_config=fd_config,
|
||||
layer_id=layer_id,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
else:
|
||||
self.mlp = DeepSeekV3MLP(
|
||||
fd_config=fd_config,
|
||||
intermediate_size=fd_config.model_config.intermediate_size,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
"""
|
||||
self.self_attn.load_state_dict(state_dict)
|
||||
self.mlp.load_state_dict(state_dict)
|
||||
self.input_layernorm.load_state_dict(state_dict)
|
||||
self.post_attention_layernorm.load_state_dict(state_dict)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
forward_meta: ForwardMeta,
|
||||
hidden_states: paddle.Tensor,
|
||||
residual: paddle.Tensor,
|
||||
position_ids: paddle.Tensor,
|
||||
mask_encoder_batch: paddle.Tensor,
|
||||
):
|
||||
"""
|
||||
"""
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn(forward_meta, hidden_states,
|
||||
position_ids, mask_encoder_batch)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class DeepSeekV3Model(nn.Layer):
|
||||
"""
|
||||
DeepSeekV3Model
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig = None,
|
||||
):
|
||||
"""
|
||||
Initializer for the DeepSeekV3Model class.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_layers = fd_config.model_config.num_layers
|
||||
fd_config.model_config.prefix_name = "deepseek_v3"
|
||||
|
||||
self.embeddings = VocabParallelEmbedding(
|
||||
fd_config,
|
||||
num_embeddings=fd_config.model_config.vocab_size,
|
||||
embedding_dim=fd_config.model_config.hidden_size,
|
||||
params_dtype=paddle.get_default_dtype(),
|
||||
prefix="deepseek_v3.embed_tokens",
|
||||
)
|
||||
|
||||
self.decoder_layers = nn.LayerList([
|
||||
DeepSeekV3DecoderLayer(
|
||||
fd_config,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
|
||||
for i in range(self.num_layers)
|
||||
])
|
||||
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix="deepseek_v3.norm",
|
||||
)
|
||||
|
||||
def pre_process(self, forward_meta):
|
||||
"""
|
||||
"""
|
||||
seq_lens_encoder = forward_meta.seq_lens_encoder
|
||||
seq_lens_decoder = forward_meta.seq_lens_decoder
|
||||
seq_lens_this_time = forward_meta.seq_lens_this_time
|
||||
position_ids_shape = paddle.sum(seq_lens_this_time)
|
||||
|
||||
position_ids = paddle.empty(shape=position_ids_shape,
|
||||
dtype=seq_lens_encoder.dtype)
|
||||
mask_encoder_batch = paddle.empty(
|
||||
shape=position_ids_shape,
|
||||
dtype=seq_lens_encoder.dtype).unsqueeze(1)
|
||||
|
||||
get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
position_ids,
|
||||
mask_encoder_batch)
|
||||
|
||||
return position_ids, mask_encoder_batch
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
Load model parameters from a given state dictionary.
|
||||
"""
|
||||
self.embeddings.load_state_dict(state_dict)
|
||||
self.norm.load_state_dict(state_dict)
|
||||
for i in range(self.num_layers):
|
||||
logger.info(f"Start load layer {i}")
|
||||
self.decoder_layers[i].load_state_dict(state_dict)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
"""
|
||||
"""
|
||||
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
position_ids, mask_encoder_batch = self.pre_process(forward_meta)
|
||||
|
||||
residual = None
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.decoder_layers[i](
|
||||
forward_meta, hidden_states, residual, position_ids,
|
||||
mask_encoder_batch)
|
||||
hidden_states = hidden_states + residual
|
||||
out = self.norm(hidden_states)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DeepseekV3ForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
DeepseekV3ForCausalLM
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
Args:
|
||||
fd_config (FDConfig): Configurations for the LLM model.
|
||||
"""
|
||||
super().__init__(fd_config)
|
||||
self.model = DeepSeekV3Model(fd_config)
|
||||
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
fd_config,
|
||||
embedding_dim=fd_config.model_config.hidden_size,
|
||||
num_embeddings=fd_config.model_config.vocab_size,
|
||||
prefix="lm_head",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
"""
|
||||
"""
|
||||
return "DeepseekV3ForCausalLM"
|
||||
|
||||
@paddle.no_grad()
|
||||
def set_state_dict(self, state_dict):
|
||||
"""
|
||||
Load model parameters from a given state dictionary.
|
||||
"""
|
||||
self.model.load_state_dict(state_dict)
|
||||
self.lm_head.load_state_dict(state_dict)
|
||||
|
||||
def compute_logits(self, hidden_states: paddle.Tensor):
|
||||
"""
|
||||
"""
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = paddle.cast(logits, paddle.float32)
|
||||
logits[:, self.ori_vocab_size:] = -float("inf")
|
||||
return logits
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
"""
|
||||
"""
|
||||
hidden_states = self.model(ids_remove_padding, forward_meta)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DeepSeekV3PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
DeepSeekV3PretrainedModel
|
||||
"""
|
||||
|
||||
config_class = FDConfig
|
||||
|
||||
def _init_weight(self, layer):
|
||||
"""
|
||||
_init_weight
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||
|
||||
logger.info("DeepseekV3 inference model _get_tensor_parallel_mappings")
|
||||
|
||||
from paddleformers.transformers.conversion_utils import \
|
||||
split_or_merge_func
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
|
||||
def get_tensor_parallel_split_mappings(num_layers):
|
||||
final_actions = {}
|
||||
|
||||
base_actions = {
|
||||
"lm_head.weight": partial(fn, is_column=True),
|
||||
"embed_tokens.weight": partial(fn, is_column=False),
|
||||
"layers.0.self_attn.o_proj.weight": partial(fn,
|
||||
is_column=False),
|
||||
}
|
||||
|
||||
# Self Attention Layer which are need TP.
|
||||
base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions[
|
||||
"layers.0.self_attn.q_b_proj.weight_scale_inv"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions[
|
||||
"layers.0.self_attn.kv_b_proj.weight_scale_inv"] = partial(
|
||||
fn, is_column=True)
|
||||
|
||||
# MLP Layer
|
||||
base_actions["layers.0.mlp.gate_proj.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions["layers.0.mlp.up_proj.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions["layers.0.mlp.down_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
|
||||
# Moe Layer
|
||||
for expert_idx in range(config.n_routed_experts):
|
||||
base_actions[
|
||||
f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions[
|
||||
f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions[
|
||||
f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
|
||||
# Shared Expert Layer
|
||||
base_actions[
|
||||
"layers.0.mlp.shared_experts.up_proj.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions[
|
||||
"layers.0.mlp.shared_experts.gate_proj.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
base_actions[
|
||||
"layers.0.mlp.shared_experts.down_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
|
||||
# MTP parts
|
||||
base_actions["layers.61.embed_tokens.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions["layers.61.eh_proj.weight"] = partial(fn,
|
||||
is_column=True)
|
||||
base_actions["layers.61.shared_head.head.weight"] = partial(
|
||||
fn, is_column=True)
|
||||
|
||||
for key, action in base_actions.items():
|
||||
if "layers.0." in key:
|
||||
for i in range(num_layers):
|
||||
final_actions[key.replace("layers.0.",
|
||||
f"layers.{i}.")] = action
|
||||
final_actions[key] = action
|
||||
|
||||
return final_actions
|
||||
|
||||
mappings = get_tensor_parallel_split_mappings(config.num_layers)
|
||||
return mappings
|
@@ -29,7 +29,7 @@ from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import \
|
||||
support_graph_optimization
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
|
||||
@@ -37,291 +37,13 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
|
||||
from fastdeploy.model_executor.models.utils import \
|
||||
LayerIdPlaceholder as layerid
|
||||
from fastdeploy.model_executor.models.utils import WeightMeta
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
class Ernie4_5_PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
Ernie4_5_PretrainedModel
|
||||
"""
|
||||
|
||||
config_class = FDConfig
|
||||
|
||||
def _init_weight(self, layer):
|
||||
"""
|
||||
_init_weight
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
|
||||
"""
|
||||
get_tensor_parallel_mappings
|
||||
"""
|
||||
logger.info("erine inference model _get_tensor_parallel_mappings")
|
||||
|
||||
from paddleformers.transformers.conversion_utils import \
|
||||
split_or_merge_func
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
|
||||
def gqa_qkv_split_func(
|
||||
weight,
|
||||
tensor_parallel_degree,
|
||||
tensor_parallel_rank,
|
||||
num_attention_heads,
|
||||
num_key_value_heads,
|
||||
head_dim,
|
||||
):
|
||||
|
||||
def get_shape(tensor):
|
||||
return (tensor.get_shape()
|
||||
if hasattr(tensor, "get_shape") else tensor.shape)
|
||||
|
||||
def slice_tensor(tensor, start, end):
|
||||
shape = get_shape(tensor)
|
||||
if len(shape) == 1:
|
||||
return tensor[start:end]
|
||||
else:
|
||||
return tensor[..., start:end]
|
||||
|
||||
q_end = num_attention_heads * head_dim
|
||||
k_end = q_end + num_key_value_heads * head_dim
|
||||
v_end = k_end + num_key_value_heads * head_dim
|
||||
|
||||
q = slice_tensor(weight, 0, q_end)
|
||||
k = slice_tensor(weight, q_end, k_end)
|
||||
v = slice_tensor(weight, k_end, v_end)
|
||||
|
||||
def split_tensor(tensor, degree):
|
||||
shape = get_shape(tensor)
|
||||
size = shape[-1]
|
||||
block_size = size // degree
|
||||
if hasattr(tensor, "get_shape"):
|
||||
return [
|
||||
slice_tensor(tensor, i * block_size,
|
||||
(i + 1) * block_size)
|
||||
for i in range(degree)
|
||||
]
|
||||
else:
|
||||
return np.split(tensor, degree, axis=-1)
|
||||
|
||||
q_list = split_tensor(q, tensor_parallel_degree)
|
||||
k_list = split_tensor(k, tensor_parallel_degree)
|
||||
v_list = split_tensor(v, tensor_parallel_degree)
|
||||
|
||||
if tensor_parallel_rank is None:
|
||||
return [
|
||||
np.concatenate([q_i, k_i, v_i], axis=-1)
|
||||
for q_i, k_i, v_i in zip(q_list, k_list, v_list)
|
||||
]
|
||||
else:
|
||||
return np.concatenate(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
def gqa_qkv_merge_func(weight_list, num_attention_heads,
|
||||
num_key_value_heads, head_dim):
|
||||
tensor_parallel_degree = len(weight_list)
|
||||
num_attention_heads = num_attention_heads // tensor_parallel_degree
|
||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
||||
|
||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||
|
||||
def get_shape(tensor):
|
||||
return (tensor.get_shape()
|
||||
if hasattr(tensor, "get_shape") else tensor.shape)
|
||||
|
||||
def slice_tensor(tensor, start, end):
|
||||
if len(get_shape(tensor)) == 1:
|
||||
return tensor[start:end]
|
||||
else:
|
||||
return tensor[..., start:end]
|
||||
|
||||
q_list, k_list, v_list = [], [], []
|
||||
|
||||
for weight in weight_list:
|
||||
q_end = num_attention_heads * head_dim
|
||||
k_end = q_end + num_key_value_heads * head_dim
|
||||
v_end = k_end + num_key_value_heads * head_dim
|
||||
|
||||
q = slice_tensor(weight, 0, q_end)
|
||||
k = slice_tensor(weight, q_end, k_end)
|
||||
v = slice_tensor(weight, k_end, v_end)
|
||||
|
||||
q_list.append(q)
|
||||
k_list.append(k)
|
||||
v_list.append(v)
|
||||
|
||||
merged = q_list + k_list + v_list
|
||||
|
||||
if is_paddle_tensor:
|
||||
tensor = paddle.concat(merged, axis=-1)
|
||||
if tensor.place.is_gpu_place():
|
||||
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
|
||||
return tensor
|
||||
else:
|
||||
return np.concatenate(merged, axis=-1)
|
||||
|
||||
if (config.num_key_value_heads is not None
|
||||
and config.num_key_value_heads != config.num_attention_heads):
|
||||
if is_split:
|
||||
qkv_fn = partial(
|
||||
gqa_qkv_split_func,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
)
|
||||
else:
|
||||
qkv_fn = partial(
|
||||
gqa_qkv_merge_func,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
)
|
||||
else:
|
||||
qkv_fn = partial(fn, is_column=True)
|
||||
|
||||
def get_tensor_parallel_split_mappings(num_layers, moe_num_experts,
|
||||
moe_num_shared_experts,
|
||||
moe_layer_start_index):
|
||||
|
||||
final_actions = {}
|
||||
|
||||
base_model_prefix = "ernie"
|
||||
base_actions = {
|
||||
"lm_head.weight":
|
||||
partial(fn, is_column=True),
|
||||
# "eh_proj.weight": partial(fn, is_column=True),
|
||||
f"{base_model_prefix}.embed_tokens.weight":
|
||||
partial(fn, is_column=False),
|
||||
}
|
||||
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.self_attn.qkv_proj.weight"] = qkv_fn
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.self_attn.qkv_proj.quant_weight"] = qkv_fn
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.self_attn.o_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.self_attn.o_proj.quant_weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.mlp.down_proj.weight"] = (
|
||||
partial(fn, is_column=False))
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"] = partial(
|
||||
fn, is_column=False)
|
||||
|
||||
for expert_idx in range(moe_num_experts):
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.experts.{expert_idx}.up_gate_proj.weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.experts.{expert_idx}.up_gate_proj.quant_weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.experts.{expert_idx}.down_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.experts.{expert_idx}.down_proj.quant_weight"] = partial(
|
||||
fn, is_column=False)
|
||||
|
||||
if moe_num_shared_experts > 0:
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.shared_experts.up_gate_proj.weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.shared_experts.down_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial(
|
||||
fn, is_column=False, is_naive_2fuse=True)
|
||||
|
||||
for key, action in base_actions.items():
|
||||
if (f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"
|
||||
in key or
|
||||
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"
|
||||
in key
|
||||
or f"{base_model_prefix}.layers.0.mlp.down_proj.weight"
|
||||
in key or
|
||||
f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"
|
||||
in key):
|
||||
for i in range(moe_layer_start_index):
|
||||
final_actions[key.replace("layers.0.",
|
||||
f"layers.{i}.")] = action
|
||||
elif f"layers.{moe_layer_start_index}.mlp.experts." in key:
|
||||
for i in range(moe_layer_start_index, num_layers):
|
||||
final_actions[key.replace(
|
||||
f"layers.{moe_layer_start_index}.",
|
||||
f"layers.{i}.")] = action
|
||||
elif f"layers.{moe_layer_start_index}.mlp.shared_experts." in key:
|
||||
for i in range(moe_layer_start_index, num_layers):
|
||||
final_actions[key.replace(
|
||||
f"layers.{moe_layer_start_index}.",
|
||||
f"layers.{i}.")] = action
|
||||
elif f"{base_model_prefix}.layers.0." in key:
|
||||
for i in range(num_layers):
|
||||
final_actions[key.replace("layers.0.",
|
||||
f"layers.{i}.")] = action
|
||||
final_actions[key] = action
|
||||
return final_actions
|
||||
|
||||
moe_num_experts = 0
|
||||
moe_num_shared_experts = 0
|
||||
if isinstance(config.moe_num_experts, list):
|
||||
moe_num_experts = sum(config.moe_num_experts)
|
||||
elif isinstance(config.moe_num_experts, int):
|
||||
moe_num_experts = config.moe_num_experts
|
||||
if hasattr(config, 'moe_num_shared_experts'):
|
||||
moe_num_shared_experts = config.moe_num_shared_experts
|
||||
|
||||
moe_layer_start_index = -1
|
||||
if isinstance(config.moe_layer_start_index, list):
|
||||
moe_layer_start_index = min(config.moe_layer_start_index)
|
||||
elif isinstance(config.moe_layer_start_index, int):
|
||||
moe_layer_start_index = config.moe_layer_start_index
|
||||
|
||||
mappings = get_tensor_parallel_split_mappings(
|
||||
config.num_layers,
|
||||
moe_num_experts,
|
||||
moe_num_shared_experts,
|
||||
moe_layer_start_index,
|
||||
)
|
||||
|
||||
return mappings
|
||||
|
||||
|
||||
class Ernie4_5_MLP(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
@@ -329,6 +51,7 @@ class Ernie4_5_MLP(nn.Layer):
|
||||
fd_config: FDConfig,
|
||||
intermediate_size: int,
|
||||
prefix: str = "",
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
@@ -339,13 +62,12 @@ class Ernie4_5_MLP(nn.Layer):
|
||||
output_size=intermediate_size * 2,
|
||||
with_bias=False,
|
||||
activation=fd_config.model_config.hidden_act,
|
||||
use_fast_ffn=True,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
input_size=(intermediate_size // self.nranks),
|
||||
input_size=intermediate_size,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
with_bias=False,
|
||||
)
|
||||
@@ -423,8 +145,8 @@ class Ernie4_5_MoE(nn.Layer):
|
||||
f"{prefix}.experts.{{}}.down_proj.code_zp",
|
||||
}
|
||||
elif moe_quant_type == "tensor_wise_fp8" or (
|
||||
moe_quant_type == "block_wise_fp8" and
|
||||
fd_config.model_config.is_quantized):
|
||||
moe_quant_type == "block_wise_fp8"
|
||||
and fd_config.model_config.is_quantized):
|
||||
weight_key_map = {
|
||||
"gate_weight_key":
|
||||
f"{prefix}.gate.weight",
|
||||
@@ -492,8 +214,6 @@ class Ernie4_5_Attention(nn.Layer):
|
||||
prefix: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
@@ -502,8 +222,8 @@ class Ernie4_5_Attention(nn.Layer):
|
||||
self.o_proj = RowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=(fd_config.model_config.head_dim *
|
||||
fd_config.model_config.num_attention_heads // nranks),
|
||||
input_size=fd_config.model_config.head_dim *
|
||||
fd_config.model_config.num_attention_heads,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
)
|
||||
self.attn = Attention(
|
||||
@@ -636,12 +356,12 @@ class Ernie4_5_Model(nn.Layer):
|
||||
params_dtype=paddle.get_default_dtype(),
|
||||
prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"))
|
||||
|
||||
self.hidden_layers = [
|
||||
self.hidden_layers = nn.LayerList([
|
||||
Ernie4_5_DecoderLayer(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
])
|
||||
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
@@ -772,3 +492,134 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
|
||||
Model Architecture Name
|
||||
"""
|
||||
return "Ernie4_5_ForCausalLM"
|
||||
|
||||
|
||||
class Ernie4_5_PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
Ernie4_5_PretrainedModel
|
||||
"""
|
||||
|
||||
config_class = FDConfig
|
||||
|
||||
def _init_weight(self, layer):
|
||||
"""
|
||||
_init_weight
|
||||
"""
|
||||
return None
|
||||
|
||||
weight_infos = [
|
||||
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
|
||||
True, tsm.GQA),
|
||||
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
|
||||
False),
|
||||
WeightMeta(".embed_tokens.weight", False),
|
||||
WeightMeta("lm_head.weight", True),
|
||||
# quant tensorwise
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.quant_weight",
|
||||
True, tsm.GQA),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.quant_weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.quant_weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.quant_weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.quant_weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.quant_weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.quant_weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.quant_weight",
|
||||
False),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
|
||||
"""
|
||||
get_tensor_parallel_mappings
|
||||
"""
|
||||
logger.info("erine inference model _get_tensor_parallel_mappings")
|
||||
from fastdeploy.model_executor.models.tp_utils import (
|
||||
build_expanded_keys, has_prefix, split_or_merge_func_v1)
|
||||
|
||||
fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim)
|
||||
|
||||
def get_tensor_parallel_split_mappings(num_layers, moe_num_experts,
|
||||
moe_layer_start_index,
|
||||
prefix_name):
|
||||
base_actions = {}
|
||||
weight_infos = cls.weight_infos
|
||||
for (weight_name, is_column, extra) in weight_infos:
|
||||
params = {
|
||||
"is_column": is_column,
|
||||
**({
|
||||
extra.value: True
|
||||
} if extra else {})
|
||||
}
|
||||
|
||||
if "lm_head.weight" in weight_name:
|
||||
key = weight_name
|
||||
elif not has_prefix(prefix_name, weight_name):
|
||||
key = f"{prefix_name}{weight_name}"
|
||||
else:
|
||||
key = weight_name
|
||||
base_actions[key] = partial(fn, **params)
|
||||
final_actions = {}
|
||||
start_layer = (moe_layer_start_index
|
||||
if moe_layer_start_index > 0 else num_layers)
|
||||
final_actions = build_expanded_keys(
|
||||
num_layers,
|
||||
moe_num_experts,
|
||||
start_layer,
|
||||
base_actions,
|
||||
)
|
||||
return final_actions
|
||||
|
||||
moe_num_experts = 0
|
||||
if isinstance(config.moe_num_experts, list):
|
||||
moe_num_experts = sum(config.moe_num_experts)
|
||||
elif isinstance(config.moe_num_experts, int):
|
||||
moe_num_experts = config.moe_num_experts
|
||||
|
||||
moe_layer_start_index = -1
|
||||
if isinstance(config.moe_layer_start_index, list):
|
||||
moe_layer_start_index = min(config.moe_layer_start_index)
|
||||
elif isinstance(config.moe_layer_start_index, int):
|
||||
moe_layer_start_index = config.moe_layer_start_index
|
||||
|
||||
mappings = get_tensor_parallel_split_mappings(config.num_layers,
|
||||
moe_num_experts,
|
||||
moe_layer_start_index,
|
||||
config.prefix_name)
|
||||
return mappings
|
||||
|
@@ -26,7 +26,7 @@ from paddleformers.transformers import PretrainedModel
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
@@ -265,12 +265,12 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
self.num_layers = fd_config.model_config.num_layers
|
||||
self.embeddings = fd_config.speculative_config.sharing_model.model.embeddings
|
||||
|
||||
self.hidden_layers = [
|
||||
self.hidden_layers = nn.LayerList([
|
||||
Ernie4_5_DecoderLayer(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.{i}")
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
])
|
||||
|
||||
self.enorm = RMSNorm(
|
||||
fd_config,
|
||||
@@ -286,7 +286,7 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
prefix="ernie.mtp_hidden_norm.0",
|
||||
)
|
||||
|
||||
self.eh_proj = ParallelLMHead(
|
||||
self.eh_proj = ParallelEHProjection(
|
||||
fd_config=fd_config,
|
||||
num_embeddings=fd_config.model_config.hidden_size,
|
||||
embedding_dim=fd_config.model_config.hidden_size * 2,
|
||||
|
@@ -25,6 +25,8 @@ from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
@@ -66,6 +68,7 @@ class Ernie4_5_VLMoE(nn.Layer):
|
||||
prefix: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
|
||||
moe_layer_start_index = fd_config.moe_config.moe_layer_start_index
|
||||
if isinstance(moe_layer_start_index, int):
|
||||
text_moe_layer_start_index = moe_layer_start_index
|
||||
@@ -99,6 +102,7 @@ class Ernie4_5_VLMoE(nn.Layer):
|
||||
}
|
||||
self.mlp_text = FusedMoE(
|
||||
fd_config=fd_config,
|
||||
reduce_results=False,
|
||||
moe_intermediate_size=fd_config.moe_config.
|
||||
moe_intermediate_size[0],
|
||||
num_experts=fd_config.moe_config.num_experts[0],
|
||||
@@ -130,6 +134,7 @@ class Ernie4_5_VLMoE(nn.Layer):
|
||||
}
|
||||
self.mlp_image = FusedMoE(
|
||||
fd_config=fd_config,
|
||||
reduce_results=False,
|
||||
moe_intermediate_size=fd_config.moe_config.
|
||||
moe_intermediate_size[1],
|
||||
num_experts=fd_config.moe_config.num_experts[1],
|
||||
@@ -154,6 +159,7 @@ class Ernie4_5_VLMoE(nn.Layer):
|
||||
intermediate_size=self.num_shared_experts *
|
||||
fd_config.moe_config.moe_intermediate_size[0],
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
reduce_results=False,
|
||||
)
|
||||
|
||||
def extract_gate_correction_bias_text(self, gate_correction_bias_key,
|
||||
@@ -210,6 +216,8 @@ class Ernie4_5_VLMoE(nn.Layer):
|
||||
hidden_states = self.mlp_text(hidden_states)
|
||||
if self.num_shared_experts > 0:
|
||||
hidden_states += share_experts_out
|
||||
if self.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -337,12 +345,12 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"),
|
||||
)
|
||||
|
||||
self.hidden_layers = [
|
||||
self.hidden_layers = nn.LayerList([
|
||||
Ernie4_5_VLDecoderLayer(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
])
|
||||
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
|
@@ -29,6 +29,7 @@ class ModelRegistry:
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class):
|
||||
"""register model class"""
|
||||
if issubclass(
|
||||
model_class,
|
||||
ModelForCasualLM) and model_class is not ModelForCasualLM:
|
||||
@@ -37,6 +38,7 @@ class ModelRegistry:
|
||||
|
||||
@classmethod
|
||||
def get_class(cls, name):
|
||||
"""get model class"""
|
||||
if name not in cls._registry:
|
||||
raise ValueError(f"Model '{name}' is not registered!")
|
||||
return cls._registry[name]
|
||||
|
@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import \
|
||||
support_graph_optimization
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
|
||||
@@ -55,13 +55,12 @@ class Qwen2MLP(nn.Layer):
|
||||
output_size=fd_config.model_config.ffn_hidden_size * 2,
|
||||
with_bias=False,
|
||||
activation=fd_config.model_config.hidden_act,
|
||||
use_fast_ffn=True,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
input_size=(fd_config.model_config.ffn_hidden_size // self.nranks),
|
||||
input_size=fd_config.model_config.ffn_hidden_size,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
with_bias=False,
|
||||
)
|
||||
@@ -97,8 +96,6 @@ class Qwen2Attention(nn.Layer):
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(fd_config=fd_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
with_bias=True)
|
||||
@@ -106,7 +103,7 @@ class Qwen2Attention(nn.Layer):
|
||||
self.o_proj = RowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=(fd_config.model_config.hidden_size // nranks),
|
||||
input_size=fd_config.model_config.hidden_size,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
)
|
||||
|
||||
@@ -305,6 +302,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
super(Qwen2ForCausalLM, self).__init__(fd_config)
|
||||
|
||||
self.fd_config =fd_config
|
||||
self.model = Qwen2Model(fd_config=fd_config)
|
||||
|
||||
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
|
||||
|
@@ -26,7 +26,7 @@ from paddleformers.utils.log import logger
|
||||
from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import \
|
||||
support_graph_optimization
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
@@ -68,7 +68,7 @@ class Qwen3Attention(nn.Layer):
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=fd_config.model_config.head_dim *
|
||||
fd_config.model_config.num_attention_heads // nranks,
|
||||
fd_config.model_config.num_attention_heads,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
)
|
||||
|
||||
|
@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig, ModelConfig
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import \
|
||||
support_graph_optimization
|
||||
from fastdeploy.model_executor.layers.activation import SiluAndMul
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
|
||||
@@ -57,13 +57,12 @@ class Qwen3MLP(nn.Layer):
|
||||
output_size=fd_config.model_config.ffn_hidden_size * 2,
|
||||
with_bias=False,
|
||||
activation=fd_config.model_config.hidden_act,
|
||||
use_fast_ffn=True,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(
|
||||
fd_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
input_size=(fd_config.model_config.ffn_hidden_size // self.nranks),
|
||||
input_size=fd_config.model_config.ffn_hidden_size,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
with_bias=False,
|
||||
)
|
||||
@@ -111,7 +110,7 @@ class Qwen3Attention(nn.Layer):
|
||||
fd_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=fd_config.model_config.head_dim *
|
||||
fd_config.model_config.num_attention_heads // nranks,
|
||||
fd_config.model_config.num_attention_heads,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
)
|
||||
|
||||
|
405
fastdeploy/model_executor/models/tp_utils.py
Normal file
405
fastdeploy/model_executor/models/tp_utils.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import re
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
from paddleformers.transformers.conversion_utils import split_or_merge_func
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder
|
||||
|
||||
|
||||
def check_tensor_parallel_prerequisites(
|
||||
fd_config: FDConfig,
|
||||
cls: PretrainedModel,
|
||||
tensor_parallel_filtered_map: Dict[str, partial],
|
||||
safetensor_keys: List[str],
|
||||
) -> None:
|
||||
"""check_tensor_parallel_prerequisites"""
|
||||
if fd_config.parallel_config.tensor_parallel_degree > 1:
|
||||
tensor_parallel_map = cls._get_tensor_parallel_mappings(
|
||||
fd_config.model_config, is_split=True)
|
||||
if not tensor_parallel_map:
|
||||
logger.error("filtered_quant_map should not be empty. \
|
||||
parallel splitting required, but _get_tensor_parallel_mappings is not implemented."
|
||||
)
|
||||
filtered_tp_keys = cls._resolve_prefix_keys(tensor_parallel_map.keys(),
|
||||
safetensor_keys)
|
||||
for k, v in filtered_tp_keys.items():
|
||||
tensor_parallel_filtered_map[v] = tensor_parallel_map.pop(k)
|
||||
if not tensor_parallel_filtered_map:
|
||||
logger.error("tensor_parallel_filtered_map should not be empty. \
|
||||
The weights required for tensor parallel splitting are inconsistent with the model's weights."
|
||||
)
|
||||
|
||||
|
||||
def extract_prefix(weight_name: str) -> str:
|
||||
"""extract_prefix"""
|
||||
if weight_name.startswith("."):
|
||||
return ""
|
||||
parts = weight_name.split(".", 1)
|
||||
return parts[0] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
def has_prefix(prefix_name: str, weight_name: str):
|
||||
"""has_prefix"""
|
||||
return prefix_name == extract_prefix(weight_name)
|
||||
|
||||
|
||||
class TensorSplitMode(Enum):
|
||||
"""TensorSplitMode"""
|
||||
|
||||
GQA = "is_gqa"
|
||||
TRANSPOSE = "transpose"
|
||||
QKV = "is_old_qkv"
|
||||
PairFused = "is_naive_2fuse"
|
||||
TripletFused = "is_naive_3fuse"
|
||||
|
||||
|
||||
def extract_placeholders(template: str):
|
||||
"""extract_placeholders"""
|
||||
return set(re.findall(r"{(\w+)}", template))
|
||||
|
||||
|
||||
class SafeDict(dict):
|
||||
"""SafeDict"""
|
||||
|
||||
def __missing__(self, key):
|
||||
return "{" + key + "}"
|
||||
|
||||
|
||||
def has_placeholders(placeholders):
|
||||
"""has_placeholders"""
|
||||
return len(placeholders) > 0
|
||||
|
||||
|
||||
def update_final_actions(params, final_actions, key, action):
|
||||
"""update_final_actions"""
|
||||
new_key = key.format_map(SafeDict(params))
|
||||
final_actions[new_key] = action
|
||||
|
||||
|
||||
def build_expanded_keys(num_layers, num_experts, start_layer, base_actions):
|
||||
"""build_expanded_keys"""
|
||||
final_actions = {}
|
||||
for key, action in base_actions.items():
|
||||
placeholders = extract_placeholders(key)
|
||||
if not has_placeholders(placeholders):
|
||||
final_actions[key] = action
|
||||
else:
|
||||
if LayerIdPlaceholder.LAYER_ID.value in placeholders:
|
||||
for layer_id in range(num_layers):
|
||||
update_final_actions(
|
||||
{LayerIdPlaceholder.LAYER_ID.value: layer_id},
|
||||
final_actions,
|
||||
key,
|
||||
action,
|
||||
)
|
||||
elif LayerIdPlaceholder.FFN_LAYER_ID.value in placeholders:
|
||||
for layer_id in range(start_layer):
|
||||
update_final_actions(
|
||||
{LayerIdPlaceholder.FFN_LAYER_ID.value: layer_id},
|
||||
final_actions,
|
||||
key,
|
||||
action,
|
||||
)
|
||||
elif (LayerIdPlaceholder.MOE_LAYER_ID.value in placeholders
|
||||
and LayerIdPlaceholder.EXPERT_ID.value in placeholders):
|
||||
for layer_id in range(start_layer, num_layers):
|
||||
for export_id in range(num_experts):
|
||||
update_final_actions(
|
||||
{
|
||||
LayerIdPlaceholder.MOE_LAYER_ID.value:
|
||||
layer_id,
|
||||
LayerIdPlaceholder.EXPERT_ID.value: export_id,
|
||||
},
|
||||
final_actions,
|
||||
key,
|
||||
action,
|
||||
)
|
||||
elif (LayerIdPlaceholder.MOE_LAYER_ID.value in placeholders
|
||||
and len(placeholders) == 1):
|
||||
for layer_id in range(start_layer, num_layers):
|
||||
update_final_actions(
|
||||
{LayerIdPlaceholder.MOE_LAYER_ID.value: layer_id},
|
||||
final_actions,
|
||||
key,
|
||||
action,
|
||||
)
|
||||
else:
|
||||
logger.error(f"{key} does not match any case.")
|
||||
return final_actions
|
||||
|
||||
|
||||
def gqa_qkv_split_func(
|
||||
tensor_parallel_degree,
|
||||
tensor_parallel_rank,
|
||||
num_attention_heads,
|
||||
num_key_value_heads,
|
||||
head_dim,
|
||||
):
|
||||
"""
|
||||
gqa_qkv_split_func
|
||||
"""
|
||||
|
||||
def fn(x, is_column=True):
|
||||
"""fucn"""
|
||||
|
||||
def get_shape(tensor):
|
||||
"""get_shape"""
|
||||
return tensor.get_shape() if hasattr(tensor,
|
||||
"get_shape") else tensor.shape
|
||||
|
||||
def slice_tensor(tensor, start, end):
|
||||
"""slice_tensor"""
|
||||
shape = get_shape(tensor)
|
||||
if len(shape) == 1:
|
||||
return tensor[start:end]
|
||||
elif is_column:
|
||||
return tensor[..., start:end]
|
||||
else:
|
||||
return tensor[start:end, ...]
|
||||
|
||||
q_end = num_attention_heads * head_dim
|
||||
k_end = q_end + num_key_value_heads * head_dim
|
||||
v_end = k_end + num_key_value_heads * head_dim
|
||||
|
||||
q = slice_tensor(x, 0, q_end)
|
||||
k = slice_tensor(x, q_end, k_end)
|
||||
v = slice_tensor(x, k_end, v_end)
|
||||
|
||||
def split_tensor(tensor, degree):
|
||||
"""
|
||||
split_tensor
|
||||
"""
|
||||
shape = get_shape(tensor)
|
||||
size = shape[-1] if is_column else shape[0]
|
||||
block_size = size // degree
|
||||
if hasattr(tensor, "get_shape"):
|
||||
return [
|
||||
slice_tensor(tensor, i * block_size, (i + 1) * block_size)
|
||||
for i in range(degree)
|
||||
]
|
||||
else:
|
||||
if isinstance(x, paddle.Tensor):
|
||||
if is_column:
|
||||
return paddle.split(tensor, degree, axis=-1)
|
||||
else:
|
||||
return paddle.split(tensor, degree, axis=0)
|
||||
else:
|
||||
if is_column:
|
||||
return np.split(tensor, degree, axis=-1)
|
||||
else:
|
||||
return np.split(tensor, degree, axis=0)
|
||||
|
||||
q_list = split_tensor(q, tensor_parallel_degree)
|
||||
k_list = split_tensor(k, tensor_parallel_degree)
|
||||
v_list = split_tensor(v, tensor_parallel_degree)
|
||||
|
||||
if tensor_parallel_rank is None:
|
||||
res = []
|
||||
for q_i, k_i, v_i in zip(q_list, k_list, v_list):
|
||||
if is_column:
|
||||
if isinstance(x, paddle.Tensor):
|
||||
res.append(paddle.concat([q_i, k_i, v_i], axis=-1))
|
||||
else:
|
||||
res.append(np.concatenate([q_i, k_i, v_i], axis=-1))
|
||||
else:
|
||||
if isinstance(x, paddle.Tensor):
|
||||
res.append(paddle.concat([q_i, k_i, v_i], axis=0))
|
||||
else:
|
||||
res.append(np.concatenate([q_i, k_i, v_i], axis=0))
|
||||
return res
|
||||
else:
|
||||
if isinstance(x, paddle.Tensor):
|
||||
if is_column:
|
||||
return paddle.concat(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
return paddle.concat(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
else:
|
||||
if is_column:
|
||||
return np.concatenate(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
return np.concatenate(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
|
||||
"""
|
||||
gqa_qkv_merge_func
|
||||
"""
|
||||
|
||||
def fn(weight_list, is_column=True):
|
||||
"""fn"""
|
||||
tensor_parallel_degree = len(weight_list)
|
||||
num_attention_heads = num_attention_heads // tensor_parallel_degree
|
||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
||||
|
||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||
|
||||
def get_shape(tensor):
|
||||
"""
|
||||
get_shape
|
||||
"""
|
||||
return tensor.get_shape() if hasattr(tensor,
|
||||
"get_shape") else tensor.shape
|
||||
|
||||
def slice_tensor(tensor, start, end):
|
||||
"""
|
||||
slice_tensor
|
||||
"""
|
||||
if len(get_shape(tensor)) == 1:
|
||||
return tensor[start:end]
|
||||
elif is_column:
|
||||
return tensor[..., start:end]
|
||||
else:
|
||||
return tensor[start:end, ...]
|
||||
|
||||
q_list, k_list, v_list = [], [], []
|
||||
|
||||
for weight in weight_list:
|
||||
q_end = num_attention_heads * head_dim
|
||||
k_end = q_end + num_key_value_heads * head_dim
|
||||
v_end = k_end + num_key_value_heads * head_dim
|
||||
|
||||
q = slice_tensor(weight, 0, q_end)
|
||||
k = slice_tensor(weight, q_end, k_end)
|
||||
v = slice_tensor(weight, k_end, v_end)
|
||||
|
||||
q_list.append(q)
|
||||
k_list.append(k)
|
||||
v_list.append(v)
|
||||
|
||||
merged = q_list + k_list + v_list
|
||||
|
||||
if is_paddle_tensor:
|
||||
if is_column:
|
||||
tensor = paddle.concat(merged, axis=-1)
|
||||
else:
|
||||
tensor = paddle.concat(merged, axis=0)
|
||||
if tensor.place.is_gpu_place():
|
||||
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
|
||||
return tensor
|
||||
else:
|
||||
if is_column:
|
||||
return np.concatenate(merged, axis=-1)
|
||||
else:
|
||||
return np.concatenate(merged, axis=0)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def split_or_merge_qkv_func(
|
||||
is_split,
|
||||
tensor_parallel_degree,
|
||||
tensor_parallel_rank,
|
||||
num_attention_heads,
|
||||
num_key_value_heads,
|
||||
head_dim,
|
||||
):
|
||||
"""
|
||||
split_or_merge_qkv_func
|
||||
"""
|
||||
if is_split:
|
||||
return gqa_qkv_split_func(
|
||||
tensor_parallel_degree=tensor_parallel_degree,
|
||||
tensor_parallel_rank=tensor_parallel_rank,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
else:
|
||||
return gqa_qkv_merge_func(
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
|
||||
def split_or_merge_func_v1(
|
||||
is_split,
|
||||
tensor_parallel_degree,
|
||||
tensor_parallel_rank,
|
||||
num_attention_heads=None,
|
||||
num_key_value_heads=None,
|
||||
head_dim=None,
|
||||
):
|
||||
"""
|
||||
split_or_merge_func_v1
|
||||
"""
|
||||
|
||||
def fn(x, **kwargs):
|
||||
"""func"""
|
||||
is_gqa = kwargs.pop("is_gqa", False)
|
||||
if is_gqa:
|
||||
func = split_or_merge_qkv_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=tensor_parallel_degree,
|
||||
tensor_parallel_rank=tensor_parallel_rank,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
is_column = kwargs.pop("is_column", True)
|
||||
return func(x, is_column=is_column)
|
||||
else:
|
||||
func = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=tensor_parallel_degree,
|
||||
tensor_parallel_rank=tensor_parallel_rank,
|
||||
num_attention_heads=num_attention_heads,
|
||||
)
|
||||
is_column = kwargs.pop("is_column", True)
|
||||
is_naive_2fuse = kwargs.pop("is_naive_2fuse", False)
|
||||
return func(x, is_column=is_column, is_naive_2fuse=is_naive_2fuse)
|
||||
|
||||
return fn
|
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
@@ -23,29 +24,47 @@ import random
|
||||
import re
|
||||
import struct
|
||||
from functools import partial
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from paddle.common_ops_import import convert_dtype
|
||||
from paddle.distributed import fleet
|
||||
from paddleformers.transformers.model_utils import (_add_variant,
|
||||
load_tp_checkpoint)
|
||||
from paddleformers.transformers.model_utils import _add_variant
|
||||
from paddleformers.transformers.utils import paddleformers_load
|
||||
from paddleformers.utils.env import (PADDLE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_MASTER_WEIGHTS_INDEX_NAME,
|
||||
SAFE_PEFT_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME)
|
||||
from paddleformers.utils.log import logger
|
||||
from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
|
||||
from fastdeploy.config import ModelConfig
|
||||
|
||||
MAX_BSZ = 512
|
||||
MAX_DRAFT_TOKENS = 6
|
||||
|
||||
|
||||
class LayerIdPlaceholder(str, enum.Enum):
|
||||
"""LayerIdPlaceholder"""
|
||||
LAYER_ID = "layer_id"
|
||||
FFN_LAYER_ID = "ffn_layer_id"
|
||||
MOE_LAYER_ID = "moe_layer_id"
|
||||
EXPERT_ID = "export_id"
|
||||
|
||||
|
||||
class WeightMeta(NamedTuple):
|
||||
"""
|
||||
#tensor split parameters
|
||||
|
||||
# weight_name: weight name
|
||||
# is_column: whether to split by columns
|
||||
# extra: optional flags like "is_naive_2fuse", "is_gqa", "is_naive_3fuse"
|
||||
"""
|
||||
weight_name: str
|
||||
is_column: bool
|
||||
extra: Optional[str] = None
|
||||
|
||||
|
||||
class UniqueIDGenerator:
|
||||
"""
|
||||
The generator for the export model id
|
||||
@@ -433,223 +452,6 @@ def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
|
||||
return total_effective_tokens, total_tokens
|
||||
|
||||
|
||||
def load_ep_checkpoint(model_path: str,
|
||||
config: ModelConfig,
|
||||
return_numpy: bool = False,
|
||||
return_key_name: bool = True):
|
||||
"""
|
||||
load ep checkpoint
|
||||
"""
|
||||
# return_numpy=True cpu
|
||||
# return_numpy=False gpu
|
||||
with open(os.path.join(model_path, "model.safetensors.index.json"),
|
||||
"r") as f:
|
||||
weight_list = json.load(f)["weight_map"]
|
||||
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
|
||||
num_local_ffn_keys = []
|
||||
|
||||
for i in range(config.moe_layer_start_index, config.num_layers):
|
||||
for j in range(
|
||||
config.num_experts_start_offset,
|
||||
config.num_experts_start_offset + config.num_experts_per_rank,
|
||||
):
|
||||
ffn1_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
|
||||
ffn2_key = (
|
||||
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
|
||||
|
||||
ffn1_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
|
||||
ffn2_quant_key = (
|
||||
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight")
|
||||
|
||||
ffn1_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
|
||||
ffn2_scale_key = (
|
||||
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale")
|
||||
num_local_ffn_keys.append(ffn1_key)
|
||||
num_local_ffn_keys.append(ffn2_key)
|
||||
num_local_ffn_keys.append(ffn1_quant_key)
|
||||
num_local_ffn_keys.append(ffn2_quant_key)
|
||||
num_local_ffn_keys.append(ffn1_scale_key)
|
||||
num_local_ffn_keys.append(ffn2_scale_key)
|
||||
|
||||
for k in num_local_ffn_keys:
|
||||
if k in weight_list:
|
||||
filtered_map[k] = weight_list[k]
|
||||
|
||||
state_dict = {}
|
||||
# Get all safetensor file paths that need to be opened
|
||||
safetensor_paths = set(filtered_map.values())
|
||||
|
||||
# Open each safetensor file sequentially with progress bar
|
||||
for safetensor_path in tqdm(safetensor_paths,
|
||||
desc="Loading safetensor files",
|
||||
unit="file"):
|
||||
with safe_open(os.path.join(model_path, safetensor_path),
|
||||
framework="np",
|
||||
device="cpu") as f:
|
||||
# Check if this file contains keys from filtered_map
|
||||
for k in filtered_map:
|
||||
if filtered_map[k] == safetensor_path and k in f.keys():
|
||||
weight = f.get_tensor(k)
|
||||
if not return_numpy:
|
||||
weight = paddle.Tensor(weight, zero_copy=True)
|
||||
weight = weight._copy_to(
|
||||
paddle.framework._current_expected_place(), False)
|
||||
state_dict[k] = weight
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_safetensor_file(model_path):
|
||||
"""
|
||||
get_safetensor_file
|
||||
"""
|
||||
with open(os.path.join(model_path, "model.safetensors.index.json"),
|
||||
"r") as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
weight_files_in_index = set()
|
||||
for weight_name in weight_map:
|
||||
weight_files_in_index.add(
|
||||
os.path.join(model_path, weight_map[weight_name]))
|
||||
key_name_list = list(set(weight_map.keys()))
|
||||
safetensor_list = list(weight_files_in_index)
|
||||
safetensor_list.sort()
|
||||
return key_name_list, safetensor_list
|
||||
|
||||
|
||||
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
|
||||
"""
|
||||
safetensors_weights_iterator
|
||||
"""
|
||||
for st_file in tqdm(
|
||||
safe_tensor_list,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
):
|
||||
with safe_open(st_file, framework="np") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(safetensor_list: list[str]):
|
||||
"""
|
||||
fastsafetensors_weights_iterator
|
||||
"""
|
||||
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
|
||||
world_size = dist.get_world_size()
|
||||
if world_size > 1:
|
||||
dist.init_parallel_env()
|
||||
pg = dist.get_group()
|
||||
device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
|
||||
else:
|
||||
pg = SingleGroup()
|
||||
device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda(
|
||||
) else "cpu"
|
||||
|
||||
safetensor_files_sub_lists = [
|
||||
safetensor_list[i:i + world_size]
|
||||
for i in range(0, len(safetensor_list), world_size)
|
||||
]
|
||||
for st_file in tqdm(
|
||||
safetensor_files_sub_lists,
|
||||
desc="Loading fastsafetensors checkpoint shards",
|
||||
):
|
||||
loader = SafeTensorsFileLoader(pg,
|
||||
device,
|
||||
nogds=True,
|
||||
debug_log=False,
|
||||
framework="paddle")
|
||||
rank_file_map = {i: [f] for i, f in enumerate(st_file)}
|
||||
loader.add_filenames(rank_file_map)
|
||||
try:
|
||||
fb = loader.copy_files_to_device()
|
||||
try:
|
||||
keys = list(fb.key_to_rank_lidx.keys())
|
||||
for k in keys:
|
||||
t = fb.get_tensor(k)
|
||||
yield k, t
|
||||
finally:
|
||||
fb.close()
|
||||
finally:
|
||||
loader.close()
|
||||
|
||||
|
||||
def get_state_dict(model_path, config, use_fastsafetensor=False):
|
||||
"""
|
||||
get_state_dict
|
||||
"""
|
||||
state_dict = {}
|
||||
_, safetensor_list = get_safetensor_file(
|
||||
os.path.join(model_path, f"rank{config.tensor_parallel_rank}"))
|
||||
if use_fastsafetensor:
|
||||
weights_iterator = fastsafetensors_weights_iterator(safetensor_list)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator(safetensor_list)
|
||||
|
||||
for name, weight in weights_iterator:
|
||||
state_dict[name] = weight
|
||||
return state_dict
|
||||
|
||||
|
||||
def apply_quant(name_action_quant_mappings, key, tensor, state_dict):
|
||||
"""
|
||||
apply_quant
|
||||
"""
|
||||
if key in name_action_quant_mappings:
|
||||
action = name_action_quant_mappings.pop(key)
|
||||
quant_weight_tensor, weight_scale_tensor = action(tensor)
|
||||
if quant_weight_tensor is not None and weight_scale_tensor is not None:
|
||||
state_dict[key + ".quant_weight"] = quant_weight_tensor
|
||||
state_dict[key + ".weight_scale"] = weight_scale_tensor
|
||||
else:
|
||||
state_dict[key] = quant_weight_tensor
|
||||
else:
|
||||
state_dict[key] = tensor
|
||||
|
||||
|
||||
def load_checkpoint(model_path, cls, config, return_numpy=True, load_gpu=True):
|
||||
"""
|
||||
load checkpoint
|
||||
"""
|
||||
if getattr(config, "parallel_config", None) is not None:
|
||||
use_ep = getattr(config.parallel_config, "use_ep", False)
|
||||
tensor_parallel_degree = config.parallel_config.tensor_parallel_degree
|
||||
else:
|
||||
use_ep = getattr(config, "use_ep", False)
|
||||
tensor_parallel_degree = config.tensor_parallel_degree
|
||||
|
||||
if getattr(config, "model_config", None) is not None:
|
||||
model_config = config.model_config
|
||||
else:
|
||||
model_config = config
|
||||
|
||||
if use_ep:
|
||||
state_dict = load_ep_checkpoint(model_path,
|
||||
config,
|
||||
return_numpy=True,
|
||||
return_key_name=True)
|
||||
else:
|
||||
rank_dirs = [
|
||||
f for f in os.listdir(model_path) if f.startswith("rank")
|
||||
and os.path.isdir(os.path.join(model_path, f))
|
||||
]
|
||||
if len(rank_dirs) > 1:
|
||||
if tensor_parallel_degree != len(rank_dirs):
|
||||
raise ValueError(
|
||||
f"Your model only supports loading with tp{len(rank_dirs)}"
|
||||
)
|
||||
state_dict = get_state_dict(model_path, model_config)
|
||||
else:
|
||||
state_dict = load_tp_checkpoint(model_path,
|
||||
cls,
|
||||
model_config,
|
||||
return_numpy=return_numpy)
|
||||
import re
|
||||
for k, v in state_dict.items():
|
||||
match = re.search(r'layers\.(\d+)', k)
|
||||
if match and int(match.group(1)) > 0:
|
||||
continue
|
||||
return state_dict
|
||||
|
||||
|
||||
def parser_quant_type(quant_type):
|
||||
"""
|
||||
Parse the quantization type string and return the corresponding quantization types for weights,
|
||||
|
@@ -13,10 +13,19 @@
|
||||
# limitations under the License.
|
||||
"""fastdeploy gpu ops"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from fastdeploy.import_ops import import_custom_ops
|
||||
|
||||
PACKAGE = "fastdeploy.model_executor.ops.gpu"
|
||||
|
||||
import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
|
||||
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
|
||||
|
||||
|
||||
def tolerant_import_error():
|
||||
class NoneModule:
|
||||
def __getattr__(self, name):
|
||||
return None
|
||||
|
||||
sys.modules[__name__] = NoneModule()
|
||||
|
354
fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py
Normal file
354
fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
import triton
|
||||
|
||||
from .triton_utils import (SubstituteTemplate, build_package, compile_file,
|
||||
extract_triton_kernel, find_so_path,
|
||||
get_pointer_hint, link_file, multi_process_do,
|
||||
python_path, rename_c_to_cu)
|
||||
|
||||
|
||||
def get_value_hint(x):
|
||||
"""
|
||||
Get the value hint from input list.
|
||||
"""
|
||||
hint = ""
|
||||
for ele in x:
|
||||
if isinstance(ele, int):
|
||||
hint += "i64,"
|
||||
continue
|
||||
if ele % 16 == 0 and ele > 0:
|
||||
hint += "i64:16,"
|
||||
elif ele == 1:
|
||||
hint += "i64:1,"
|
||||
else:
|
||||
hint += "i64,"
|
||||
if isinstance(ele, float):
|
||||
hint += "fp32,"
|
||||
return hint
|
||||
|
||||
|
||||
common_template = ("""
|
||||
#include "${op_name}_kernel.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
void ${op_name}_func(${tensor_and_attr}) {
|
||||
auto run_stream = a_ptr->stream();
|
||||
auto res_flag = ${op_name}_kernel(run_stream, ${triton_kernel_args}, 0);
|
||||
if (res_flag == CUDA_ERROR_INVALID_VALUE) {
|
||||
PD_THROW("${op_name}_kernel failed");
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(${op_name}_package, m) {
|
||||
|
||||
m.def("${op_name}_func", ${op_name}_func, "get expert token num");
|
||||
}
|
||||
|
||||
""")
|
||||
|
||||
|
||||
class KernelInterface:
|
||||
"""
|
||||
triton kernel interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func,
|
||||
other_config,
|
||||
key_args=["1"],
|
||||
):
|
||||
"""
|
||||
triton kernel interface.
|
||||
"""
|
||||
self.func = func
|
||||
self.key_args = key_args
|
||||
|
||||
signature = inspect.signature(func)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
for ele in self.arg_names:
|
||||
assert self.arg_names.count(ele) == 1
|
||||
# arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
|
||||
# self.annotations = {
|
||||
# name: ty for name, ty in func.__annotations__.items()
|
||||
# }
|
||||
self.annotations = dict(func.__annotations__)
|
||||
|
||||
self.constexprs = [
|
||||
self.arg_names.index(name) for name in self.arg_names
|
||||
if self.annotations.get(name) == triton.language.core.constexpr
|
||||
]
|
||||
|
||||
self.arg_exclude_constexpr = [
|
||||
self.arg_names[i] for i in range(len(self.arg_names))
|
||||
if i not in self.constexprs
|
||||
]
|
||||
|
||||
import textwrap
|
||||
|
||||
py_script = textwrap.dedent(inspect.getsource(func))
|
||||
|
||||
pat = r"def\s" + func.__name__
|
||||
func_begin = re.findall(pat, py_script)
|
||||
assert len(func_begin) == 1
|
||||
func_begin = func_begin[0]
|
||||
py_script = py_script[py_script.find(func_begin):]
|
||||
|
||||
self.func_map = {}
|
||||
|
||||
def decorator(*args, **kwargs):
|
||||
"""
|
||||
decorator for triton kernels.
|
||||
Args:
|
||||
*args: positional arguments
|
||||
**kwargs: keyword arguments
|
||||
"""
|
||||
op_name = "haha" + str(kwargs["N"])
|
||||
if op_name in self.func_map.keys():
|
||||
return self.func_map[op_name](*args)
|
||||
|
||||
all_input = []
|
||||
|
||||
for i in range(len(args)):
|
||||
all_input.append(args[i])
|
||||
|
||||
position_arguments_num = len(all_input)
|
||||
for i in range(position_arguments_num, len(self.arg_names)):
|
||||
if self.arg_names[i] in kwargs.keys():
|
||||
all_input.append(kwargs[self.arg_names[i]])
|
||||
else:
|
||||
# means this input is not specified, it muse be a tl.constexpr.
|
||||
assert i in self.constexprs
|
||||
all_input.append(None)
|
||||
|
||||
dtypes = []
|
||||
x_list = []
|
||||
const_args = [self.arg_names[i] for i in self.constexprs]
|
||||
|
||||
decalare_arg_exclude_constexpr = list(self.arg_exclude_constexpr)
|
||||
passed_arg_exclude_constexpr = list(self.arg_exclude_constexpr)
|
||||
|
||||
const_hint_dict = {}
|
||||
for i in range(len(all_input)):
|
||||
ele = all_input[i]
|
||||
|
||||
if type(ele) in [
|
||||
paddle.Tensor, paddle.base.framework.EagerParamBase,
|
||||
paddle.base.framework.Parameter,
|
||||
paddle.base.framework.Variable,
|
||||
paddle.base.libpaddle.pir.Value,
|
||||
type(None)
|
||||
]:
|
||||
if ele is not None:
|
||||
dtypes.append(ele.dtype)
|
||||
passed_arg_exclude_constexpr[
|
||||
i] = f"(CUdeviceptr)({passed_arg_exclude_constexpr[i]}->data())"
|
||||
else:
|
||||
dtypes.append(paddle.int8)
|
||||
passed_arg_exclude_constexpr[
|
||||
i] = "(CUdeviceptr)(nullptr)"
|
||||
decalare_arg_exclude_constexpr[
|
||||
i] = "const paddle::optional<paddle::Tensor>&" + decalare_arg_exclude_constexpr[
|
||||
i]
|
||||
elif i in self.constexprs:
|
||||
if isinstance(ele, bool):
|
||||
const_hint_dict[self.arg_names[i]] = (int)(ele)
|
||||
elif isinstance(ele, int):
|
||||
if ele < 0:
|
||||
const_hint_dict[self.arg_names[i]] = 0
|
||||
else:
|
||||
const_hint_dict[self.arg_names[i]] = ele
|
||||
else:
|
||||
assert False
|
||||
else:
|
||||
x_list.append(ele)
|
||||
if isinstance(ele, int):
|
||||
decalare_arg_exclude_constexpr[
|
||||
i] = "const int64_t " + decalare_arg_exclude_constexpr[
|
||||
i]
|
||||
elif isinstance(ele, float):
|
||||
decalare_arg_exclude_constexpr[
|
||||
i] = "const float " + decalare_arg_exclude_constexpr[
|
||||
i]
|
||||
else:
|
||||
assert False
|
||||
|
||||
python_package_name = f"{op_name}_package"
|
||||
tp_rank = paddle.distributed.get_rank()
|
||||
|
||||
generated_dir = os.getenv("TRITON_KERNEL_CACHE_DIR",
|
||||
f"/tmp/triton_cache/rank{tp_rank}")
|
||||
print("the kernel cache dir is:", generated_dir)
|
||||
generated_dir = f"{generated_dir}/{op_name}"
|
||||
os.makedirs(generated_dir, exist_ok=True)
|
||||
|
||||
py_script_file = f"{generated_dir}/triton_kernels.py"
|
||||
extract_triton_kernel(func, py_script_file)
|
||||
|
||||
address_hint = get_pointer_hint(dtypes)
|
||||
value_hint = get_value_hint(x_list)
|
||||
const_args = [f"{{{ele}}}" for ele in const_args]
|
||||
const_args = ",".join(const_args)
|
||||
|
||||
lanuch_grid = list(self.grid)
|
||||
for i in range(len(lanuch_grid)):
|
||||
ele = lanuch_grid[i]
|
||||
if isinstance(ele, str):
|
||||
keys = list(const_hint_dict.keys())
|
||||
keys.sort(key=len, reverse=True)
|
||||
for key in keys:
|
||||
if key in ele:
|
||||
ele = ele.replace(key, f"{const_hint_dict[key]}")
|
||||
else:
|
||||
ele = str(ele)
|
||||
lanuch_grid[i] = ele
|
||||
|
||||
if len(lanuch_grid) < 3:
|
||||
lanuch_grid += ["1"] * (3 - len(lanuch_grid))
|
||||
lanuch_grid = ",".join(lanuch_grid)
|
||||
|
||||
op_dict = {"op_name": op_name}
|
||||
op_dict["triton_kernel_args"] = ",".join(
|
||||
passed_arg_exclude_constexpr)
|
||||
op_dict["tensor_and_attr"] = ",".join(
|
||||
decalare_arg_exclude_constexpr)
|
||||
|
||||
paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
|
||||
so_path = find_so_path(generated_dir, python_package_name)
|
||||
|
||||
if so_path is None:
|
||||
print("== we do not find so_path, we need to compile it")
|
||||
with open(paddle_custom_op_file_path, "w") as f:
|
||||
f.write(SubstituteTemplate(
|
||||
common_template,
|
||||
op_dict,
|
||||
))
|
||||
f.close()
|
||||
|
||||
# ahead of time compile command.
|
||||
aot_template = (
|
||||
f"""{python_path} {compile_file} {py_script_file} """ +
|
||||
f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
|
||||
+ f"""--out-name {op_name}_kernel """ +
|
||||
""" -w {num_warps} -ns {num_stages} """ +
|
||||
f""" -s"{address_hint} {value_hint} {const_args}" """ +
|
||||
f""" -g "{lanuch_grid}" """)
|
||||
|
||||
all_tune_config = [const_hint_dict]
|
||||
# reset const_hint_dict as empty.
|
||||
const_hint_dict = {}
|
||||
codegen_commands = []
|
||||
for config in all_tune_config:
|
||||
for key in const_hint_dict.keys():
|
||||
if const_hint_dict[key] is not None:
|
||||
if key not in config.keys():
|
||||
config[key] = const_hint_dict[key]
|
||||
else:
|
||||
if config[key] == const_hint_dict[key]:
|
||||
pass
|
||||
else:
|
||||
message = (
|
||||
f"you specify {key} both in arguments and config, "
|
||||
"and they are not same, this is wrong."
|
||||
)
|
||||
raise ValueError(message)
|
||||
else:
|
||||
assert key in config.keys(
|
||||
), f"you must specify {key} in your config."
|
||||
if "num_warps" not in config.keys():
|
||||
config["num_warps"] = 4
|
||||
if "num_stages" not in config.keys():
|
||||
config["num_stages"] = 4
|
||||
|
||||
for key in config:
|
||||
assert config[
|
||||
key] is not None, f"{key} must be specified."
|
||||
codegen_command = aot_template.format(**config, )
|
||||
print(codegen_command)
|
||||
codegen_commands.append(codegen_command)
|
||||
multi_process_do(codegen_commands)
|
||||
|
||||
link_command = (
|
||||
f"{python_path} {link_file} "
|
||||
f"{generated_dir}/*.h -o {generated_dir}/{op_name}_kernel")
|
||||
re = os.system(link_command)
|
||||
assert re == 0
|
||||
|
||||
# rename the .c file to .cu
|
||||
rename_c_to_cu(generated_dir)
|
||||
# build the package to so, not install
|
||||
build_package(generated_dir, python_package_name)
|
||||
|
||||
# so_path have be found!
|
||||
so_path = find_so_path(generated_dir, python_package_name)
|
||||
print("== we find so_path: ", so_path)
|
||||
assert so_path is not None
|
||||
dir_path = os.path.dirname(so_path)
|
||||
sys.path.append(dir_path)
|
||||
lib = importlib.import_module(python_package_name)
|
||||
pybind_func = getattr(lib, f"{op_name}_func")
|
||||
self.func_map[op_name] = pybind_func
|
||||
|
||||
# run this op!
|
||||
self.func_map[op_name](*args)
|
||||
|
||||
self.decorator = decorator
|
||||
|
||||
def __getitem__(self, op_name_and_grid):
|
||||
"""
|
||||
override the operator [], which will call the decorator function.
|
||||
Args:
|
||||
op_name_and_grid: the name of the operator and the grid size.
|
||||
Returns:
|
||||
the decorator function.
|
||||
"""
|
||||
self.grid = ((
|
||||
"((max_possible_num_post_padded + BLOCK_SIZE_M -1)/ BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N-1) / BLOCK_SIZE_N)"
|
||||
), )
|
||||
|
||||
return self.decorator
|
||||
|
||||
|
||||
def paddle_use_triton_v2(other_config={}, key=[]):
|
||||
"""
|
||||
The decorator function that wraps the original function.
|
||||
Args:
|
||||
func: the original function.
|
||||
Returns:
|
||||
the wrapped function.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
"""
|
||||
The decorator function that wraps the original function.
|
||||
Args:
|
||||
func: the original function.
|
||||
Returns:
|
||||
the wrapped function.
|
||||
"""
|
||||
return KernelInterface(func, other_config, key)
|
||||
|
||||
return decorator
|
@@ -17,6 +17,7 @@ from typing import Dict, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.config import SpeculativeConfig
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
get_padding_offset, save_output, set_stop_value_multi_ends,
|
||||
@@ -24,10 +25,11 @@ from fastdeploy.model_executor.ops.gpu import (
|
||||
speculate_get_padding_offset, speculate_get_seq_lens_output,
|
||||
speculate_save_output, speculate_set_value_by_flags_and_idx,
|
||||
speculate_step_paddle, speculate_step_system_cache, speculate_update_v3,
|
||||
step_paddle, step_system_cache, update_inputs)
|
||||
step_paddle, step_system_cache, update_inputs, step_reschedule)
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.output import ModelOutputData
|
||||
|
||||
DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
|
||||
|
||||
def pre_process(
|
||||
max_len: int,
|
||||
@@ -214,6 +216,8 @@ def step_cuda(
|
||||
"""
|
||||
TODO(gongshaotian): normalization name
|
||||
"""
|
||||
|
||||
|
||||
if speculative_config.method is not None:
|
||||
if enable_prefix_caching:
|
||||
speculate_step_system_cache(
|
||||
@@ -291,6 +295,33 @@ def step_cuda(
|
||||
share_inputs["input_ids"], share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"], share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"], block_size, enc_dec_block_num)
|
||||
elif DISABLE_RECOVER:
|
||||
step_reschedule(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
)
|
||||
else:
|
||||
step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user