mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
Compare commits
49 Commits
release/2.
...
v2.0.2
Author | SHA1 | Date | |
---|---|---|---|
![]() |
e421d51001 | ||
![]() |
c71d955e9c | ||
![]() |
2d2468ae72 | ||
![]() |
7deac64233 | ||
![]() |
5a5f17cf97 | ||
![]() |
0d61c65de1 | ||
![]() |
e5de28bff2 | ||
![]() |
b9eede57b6 | ||
![]() |
94e1a895e3 | ||
![]() |
87203ec87b | ||
![]() |
4596dd7248 | ||
![]() |
ec986642df | ||
![]() |
94691bcd90 | ||
![]() |
4025ea7e5b | ||
![]() |
e681e1e719 | ||
![]() |
823a47e64a | ||
![]() |
39d2a1de46 | ||
![]() |
1107e08cd9 | ||
![]() |
1fe37cb7e8 | ||
![]() |
337d76f094 | ||
![]() |
ae2f78184d | ||
![]() |
6851489425 | ||
![]() |
ea787d8f62 | ||
![]() |
90ef28d982 | ||
![]() |
b37585e693 | ||
![]() |
9cb08e71e8 | ||
![]() |
dacc46f04c | ||
![]() |
09ded7715f | ||
![]() |
11cfdf5d89 | ||
![]() |
e7fa57ebae | ||
![]() |
a5ae88ded9 | ||
![]() |
87e638498c | ||
![]() |
667547be59 | ||
![]() |
b38823bc66 | ||
![]() |
050d9658a5 | ||
![]() |
be5cabaf80 | ||
![]() |
240bdac2a4 | ||
![]() |
00863c43fd | ||
![]() |
3d3bccdf79 | ||
![]() |
9fd74f75bd | ||
![]() |
05c670e593 | ||
![]() |
d222248d00 | ||
![]() |
e5b94d4117 | ||
![]() |
87e2e58a22 | ||
![]() |
de20e5a992 | ||
![]() |
2f9c0618f0 | ||
![]() |
9a14ab6572 | ||
![]() |
d1cb3ed571 | ||
![]() |
b8a8a19689 |
83
.github/workflows/ci_xpu.yml
vendored
Normal file
83
.github/workflows/ci_xpu.yml
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
name: CI_XPU
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ develop ]
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.pull_request.number }}-xpu-ci
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: [self-hosted, XPU-P800-8Card]
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
# Because the system version is lower than 2.23, the checkout cannot be used.
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge pr/${{ github.event.pull_request.number }}
|
||||
git log -n 3 --oneline
|
||||
else
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [[ "$last_char" =~ [0-3] ]]; then
|
||||
gpu_id="$last_char"
|
||||
else
|
||||
gpu_id="0"
|
||||
fi
|
||||
FD_API_PORT=$((9180 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((9170 + gpu_id * 100))
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "/ssd3:/ssd3" \
|
||||
-e "MODEL_PATH=/ssd3/model" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci_xpu.sh
|
||||
"
|
6
.github/workflows/gh-pages.yml
vendored
6
.github/workflows/gh-pages.yml
vendored
@@ -3,8 +3,6 @@ name: Deploy GitHub Pages
|
||||
on:
|
||||
push:
|
||||
branches: [ develop ]
|
||||
pull_request:
|
||||
branches: [ develop ]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -21,4 +19,6 @@ jobs:
|
||||
- name: Deploy to GitHub Pages
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: mkdocs gh-deploy --force --remote-name origin
|
||||
run: |
|
||||
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git
|
||||
mkdocs gh-deploy --force --remote-name origin
|
||||
|
@@ -5,12 +5,6 @@ default_stages:
|
||||
- pre-commit # Run locally
|
||||
# - manual # Run in CI
|
||||
repos:
|
||||
# 格式化
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.43.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
# 代码检查
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.7
|
||||
@@ -29,15 +23,6 @@ repos:
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
# # 格式化
|
||||
# - repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
# rev: v20.1.3
|
||||
# hooks:
|
||||
# - id: clang-format
|
||||
# # exclude: '.*'
|
||||
# types_or: [c++, cuda]
|
||||
# args: [--style=file, --verbose]
|
||||
|
||||
# markdown
|
||||
- repo: https://github.com/jackdewinter/pymarkdown
|
||||
rev: v0.9.29
|
||||
|
1180
benchmarks/quick_benchmark.py
Normal file
1180
benchmarks/quick_benchmark.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,3 +3,4 @@ tqdm
|
||||
numpy
|
||||
Pillow
|
||||
pyyaml
|
||||
requests
|
||||
|
3
benchmarks/yaml/request_yaml/quick_benchmark.yaml
Normal file
3
benchmarks/yaml/request_yaml/quick_benchmark.yaml
Normal file
@@ -0,0 +1,3 @@
|
||||
metadata:
|
||||
min_tokens: 32
|
||||
max_tokens: 33
|
19
build.sh
19
build.sh
@@ -166,7 +166,7 @@ function build_and_install() {
|
||||
|
||||
echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
|
||||
cd $DIST_DIR
|
||||
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install
|
||||
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir
|
||||
if [ $? -ne 0 ]; then
|
||||
cd ..
|
||||
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
|
||||
@@ -176,6 +176,21 @@ function build_and_install() {
|
||||
cd ..
|
||||
}
|
||||
|
||||
function version_info() {
|
||||
output_file="fastdeploy/version.txt"
|
||||
fastdeploy_git_commit_id=$(git rev-parse HEAD)
|
||||
paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
|
||||
paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
|
||||
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
|
||||
cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
|
||||
|
||||
echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
|
||||
echo "Paddle version: $paddle_version" >> $output_file
|
||||
echo "Paddle GIT COMMIT ID: $paddle_git_commit_id" >> $output_file
|
||||
echo "CUDA version: $cuda_version" >> $output_file
|
||||
echo "CXX compiler version: $cxx_version" >> $output_file
|
||||
}
|
||||
|
||||
function cleanup() {
|
||||
rm -rf $BUILD_DIR $EGG_DIR
|
||||
if [ `${python} -m pip list | grep fastdeploy | wc -l` -gt 0 ]; then
|
||||
@@ -207,6 +222,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
|
||||
set -e
|
||||
|
||||
init
|
||||
version_info
|
||||
build_and_install_ops
|
||||
build_and_install
|
||||
cleanup
|
||||
@@ -237,6 +253,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
|
||||
else
|
||||
init
|
||||
build_and_install_ops
|
||||
version_info
|
||||
rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR
|
||||
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
|
||||
fi
|
||||
|
@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
|
||||
from . import jit
|
||||
from .jit_kernels import (
|
||||
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
|
||||
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
from . import interleave_ffma
|
||||
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
|
||||
index fcb377e..db9d6f3 100644
|
||||
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
|
||||
import subprocess
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
|
||||
|
||||
|
||||
|
||||
def run_cuobjdump(file_path):
|
||||
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
|
||||
index 66c370a..4761426 100644
|
||||
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from .template import map_ctype
|
||||
@@ -35,7 +35,7 @@ class Runtime:
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict, Iterable, Tuple
|
||||
|
||||
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
|
||||
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
|
||||
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
|
||||
}
|
||||
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -27,25 +27,25 @@ genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
|
||||
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
if hasattr(value, 'data_ptr'):
|
||||
- if value.dtype == torch.int:
|
||||
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
|
||||
+import paddle
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||
|
||||
|
||||
|
||||
|
||||
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor) -> None:
|
||||
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
- this function will do a transposing with a set of slow PyTorch operations.
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
|
||||
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
|
||||
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
|
||||
- assert n % 64 == 0 and k % 128 == 0
|
||||
+ # assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
- assert n > 0 and k > 0
|
||||
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
|
||||
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
+ # assert out.dtype == paddle.bfloat16
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
|
||||
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
|
||||
+ this function will do a transposing with a set of slow Pypaddle operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
m__ = m_indices.numel()
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ == m__ and k == k_ and n == n_
|
||||
- assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
|
||||
+ # assert m_indices.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
|
||||
)
|
||||
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
runtime(*args)
|
||||
|
||||
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
|
||||
+ # assert masked_m.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
- torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
@@ -51,10 +51,10 @@ class JITTuner:
|
||||
continue
|
||||
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
|
||||
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
|
||||
"""
|
||||
global _num_sms
|
||||
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
|
||||
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
return ceil_div(x, alignment) * alignment
|
||||
|
||||
|
||||
|
||||
|
||||
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
|
||||
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
||||
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
||||
|
||||
|
||||
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
|
||||
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
|
||||
# Normal layout requires transposing
|
||||
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
+ aligned_x = paddle.transpose(
|
||||
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
|
||||
-import torch.distributed as dist
|
||||
+import paddle
|
||||
+import paddle.distributed as dist
|
||||
|
||||
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
high_precision: bool = False):
|
||||
# Flush L2 cache with 256 MB data
|
||||
- torch.cuda.synchronize()
|
||||
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
+ paddle.device.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
|
||||
cache.zero_()
|
||||
|
||||
|
||||
# Warmup
|
||||
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
|
||||
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
x @ y
|
||||
|
||||
|
||||
# Testing
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
|
||||
end_event.record()
|
||||
- torch.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
|
||||
|
||||
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
|
||||
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
|
||||
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
--
|
||||
2.43.0
|
||||
|
||||
if not using_nsys:
|
||||
--
|
||||
2.43.0
|
||||
|
236
custom_ops/gpu_ops/append_attn/decode_attention_func.cuh
Normal file
236
custom_ops/gpu_ops/append_attn/decode_attention_func.cuh
Normal file
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "multi_head_latent_attention_kernel.h"
|
||||
|
||||
template <size_t vec_size, typename T>
|
||||
struct softmax_state_t {
|
||||
AlignedVector<T, vec_size> o;
|
||||
T m;
|
||||
T d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = __float2half(-5e4f);
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = __float2bfloat16(-3.38953e38f);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ softmax_state_t() {
|
||||
init();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
|
||||
T other_m,
|
||||
T other_d) {
|
||||
// using kType = typename cascade_attn_nv_type2_traits<T>::type;
|
||||
T m_prev = m, d_prev = d;
|
||||
m = m_prev > other_m ? m_prev : other_m;
|
||||
T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);
|
||||
|
||||
d = d_prev * scale1 + other_d * scale2;
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] = o[i] * scale1 + other_o[i] * scale2;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize() {
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
|
||||
struct softmax_state_ts {
|
||||
uint32_t num_tiles_ = num_tiles;
|
||||
AlignedVector<T, vec_size> o[num_tiles];
|
||||
float m;
|
||||
float d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
#pragma unroll
|
||||
for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ softmax_state_ts() {
|
||||
init();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize(const uint32_t tile_id) {
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; i++) {
|
||||
o[tile_id][i] /= d;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
|
||||
__device__ __forceinline__ void produce_kv(CacheT *smem,
|
||||
CacheT *kv_base_gptr,
|
||||
const int * block_table_smem,
|
||||
const uint32_t seq_offset_gmem,
|
||||
const uint32_t seq_offset_smem,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t tidx,
|
||||
const uint32_t chunk_start,
|
||||
const uint32_t chunk_end) {
|
||||
int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
|
||||
if (block_id < 0) {
|
||||
block_id = 0;
|
||||
}
|
||||
const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
|
||||
// 8/16 T/int8 each time
|
||||
const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
|
||||
const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
|
||||
smem + smem_offset_base + vid * CACHE_VEC_SIZE,
|
||||
kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
|
||||
seq_offset_gmem < chunk_end
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
|
||||
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
|
||||
const CacheT* k_smem,
|
||||
const uint32_t kv_idx_base,
|
||||
const uint32_t stage_idx,
|
||||
const uint32_t iter_base,
|
||||
const uint32_t iter_bound,
|
||||
const uint32_t tidx,
|
||||
const uint32_t gid,
|
||||
const float scale,
|
||||
float *s,
|
||||
softmax_state_ts<vec_size, T, num_tile_v>& st) {
|
||||
const CacheT* smem;
|
||||
AlignedVector<T, vec_size> q_vec;
|
||||
AlignedVector<T, vec_size> k_vec;
|
||||
float m_prev = st.m;
|
||||
// smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
|
||||
smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
|
||||
if (iter_base + j < iter_bound) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s[j] = 0.f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
s[j] = 0.f;
|
||||
}
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
|
||||
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
|
||||
s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s[j] = -5e4f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
s[j] = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
st.m = st.m > s[j] ? st.m : s[j];
|
||||
}
|
||||
|
||||
// T o_scale = hexp(m_prev - st.m);
|
||||
float o_scale = __expf(m_prev - st.m);
|
||||
st.d *= o_scale;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
|
||||
// s[j] = hexp(s[j] - st.m);
|
||||
s[j] = __expf(s[j] - st.m);
|
||||
st.d += s[j];
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
st.o[tile_id][i] *= o_scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
|
||||
__device__ __forceinline__ void compute_sv(const float *s,
|
||||
const CacheT *base_v_smem,
|
||||
const uint32_t stage_idx,
|
||||
const uint32_t iter_base,
|
||||
const uint32_t iter_bound,
|
||||
const uint32_t tidx,
|
||||
softmax_state_ts<vec_size, T, num_tile>& st) {
|
||||
const CacheT* v_smem;
|
||||
AlignedVector<T, vec_size> v_vec;
|
||||
#pragma unroll
|
||||
for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
|
||||
v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
|
||||
uint32_t tile_id = vid / bdx;
|
||||
#pragma unroll
|
||||
for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
|
||||
st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
560
custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu
Normal file
560
custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu
Normal file
@@ -0,0 +1,560 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decode_attention_func.cuh"
|
||||
|
||||
#define CHECK(call) \
|
||||
do \
|
||||
{ \
|
||||
const cudaError_t error_code = call; \
|
||||
if (error_code != cudaSuccess) \
|
||||
{ \
|
||||
printf("CUDA Error:\n"); \
|
||||
printf(" File: %s\n", __FILE__); \
|
||||
printf(" Line %d:\n", __LINE__); \
|
||||
printf(" Error code:%d\n", error_code); \
|
||||
printf(" Error text:%s\n", cudaGetErrorString(error_code)); \
|
||||
exit(1); \
|
||||
} \
|
||||
}while(0)
|
||||
|
||||
template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim]
|
||||
const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads]
|
||||
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cum_offsets,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const float in_scale,
|
||||
const int num_chunks,
|
||||
const int chunk_size,
|
||||
const int max_seq_len,
|
||||
const int num_heads,
|
||||
const int head_dim) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int qid = blockIdx.x, hid = blockIdx.y;
|
||||
const int seq_len_q = seq_lens_q[qid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[qid];
|
||||
if (seq_len_kv == 0) return;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
|
||||
return;
|
||||
}
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ T md_smem[bdy * 2];
|
||||
|
||||
const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]);
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
T m;
|
||||
T d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = __float2half(-5e4f);
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = __float2bfloat16(-3.38953e38f);
|
||||
}
|
||||
// merge per ty
|
||||
#pragma unroll 2
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
|
||||
T m_prev = m;
|
||||
T d_prev = d;
|
||||
const T m_now = multi_m[offset];
|
||||
const T d_now = multi_d[offset];
|
||||
m = m_prev > m_now ? m_prev : m_now;
|
||||
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma once
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
|
||||
// merge bdy
|
||||
softmax_state_t<vec_size, T> st{};
|
||||
const uint32_t iter_num = min(num_chunks_this_seq, bdy);
|
||||
#pragma once
|
||||
for (int i = 0; i < iter_num; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
|
||||
AlignedVector<OutT, vec_size> out_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size; ++i) {
|
||||
out_vec[i] = static_cast<OutT>(st.o[i]);
|
||||
}
|
||||
Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
|
||||
template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
|
||||
uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
|
||||
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim]
|
||||
CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||
CacheT * __restrict__ cache_v,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cum_offsets,
|
||||
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
const float scale,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim]
|
||||
T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads]
|
||||
T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads]
|
||||
OutT * __restrict__ out) {
|
||||
const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const uint32_t bid = bidx, gid = threadIdx.y;
|
||||
const uint32_t tidx = threadIdx.x;
|
||||
constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
|
||||
constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
|
||||
constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;
|
||||
|
||||
const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
|
||||
const uint32_t kv_num_heads = gridDim.z;
|
||||
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
|
||||
|
||||
const int *block_table_now = block_table + bid * max_block_num_per_seq;
|
||||
|
||||
const uint32_t num_chunks = gridDim.y;
|
||||
const uint32_t chunk_id = blockIdx.y;
|
||||
const uint32_t q_len = seq_lens_q[bid];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
uint32_t kv_len = seq_lens_kv[bid]; // !!!!!!!!
|
||||
if (kv_len <= 0) {
|
||||
return;
|
||||
}
|
||||
kv_len += q_len;
|
||||
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
|
||||
const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (chunk_id >= num_chunk_this_seq) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
|
||||
const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
|
||||
const uint32_t chunk_len = chunk_end - chunk_start;
|
||||
|
||||
extern __shared__ uint8_t smem[];
|
||||
const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
|
||||
T *q_smem = reinterpret_cast<T*>(smem); // [HEAD_DIM_QK * sizeof(T)]
|
||||
T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
using VecT = AlignedVector<T, VEC_SIZE>;
|
||||
VecT q_vec;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
|
||||
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
|
||||
q_vec[i] *= scale;
|
||||
}
|
||||
Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
|
||||
}
|
||||
|
||||
|
||||
CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
|
||||
uint32_t stage_idx = 0;
|
||||
constexpr int loop_times = DEAL_EACH_TIME / bdy;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_STAGES; ++i) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
|
||||
const uint32_t k_seq_id = chunk_start + k_seq_offset;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
k_seq_id,
|
||||
k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
|
||||
|
||||
softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
|
||||
float s[DEAL_EACH_TIME];
|
||||
|
||||
const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
|
||||
for (int iter = 0; iter < num_iters; ++iter) {
|
||||
wait_group<NUM_STAGES - 1>();
|
||||
__syncthreads();
|
||||
// compute qk
|
||||
compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
|
||||
cu_q_smem,
|
||||
kv_smem,
|
||||
chunk_start + iter * DEAL_EACH_TIME,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
gid,
|
||||
scale,
|
||||
s,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
// compute sv
|
||||
compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
|
||||
s,
|
||||
kv_smem,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = j * bdy + gid;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
|
||||
stage_idx * DEAL_EACH_TIME + k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// normize if not partition_kv
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
|
||||
const uint32_t tile_id = vid / bdx;
|
||||
if (!partition_kv || num_chunk_this_seq == 1) {
|
||||
st.normalize(tile_id);
|
||||
}
|
||||
if (partition_kv && num_chunk_this_seq > 1) {
|
||||
const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
|
||||
Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
tmp_m[head_idx] = st.m;
|
||||
tmp_d[head_idx] = st.d;
|
||||
} else {
|
||||
Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
|
||||
void MultiQueryDecoderAttention(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
cudaStream_t &stream,
|
||||
const paddle::Tensor &q,
|
||||
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
|
||||
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float rope_scale,
|
||||
const float rope_theta,
|
||||
const float softmax_scale,
|
||||
const float in_scale,
|
||||
paddle::Tensor *out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
constexpr int num_stages = NUM_STAGE;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T); // 8 16 32
|
||||
constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32
|
||||
constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
|
||||
constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
|
||||
constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;
|
||||
|
||||
constexpr int blocky = GROUP_SIZE;
|
||||
const int gridx = bsz;
|
||||
|
||||
constexpr int num_threads = blockx * blocky;
|
||||
|
||||
auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
|
||||
BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
|
||||
uint32_t cache_smem_bytes = 0;
|
||||
|
||||
const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
|
||||
const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
|
||||
cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
const int dev_id = 0;
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
|
||||
assert(act_blocks_per_sm > 1);
|
||||
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
const int num_blocks_need = gridx * num_chunks * kv_num_heads;
|
||||
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
|
||||
const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);
|
||||
|
||||
dim3 grids(gridx, num_chunks, kv_num_heads);
|
||||
dim3 blocks(blockx, blocky);
|
||||
if (num_chunks <= 1) {
|
||||
auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
|
||||
cache_vec_size, blockx, blocky>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
} else {
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
tmp_workspace = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
|
||||
tmp_m = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
tmp_d = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
|
||||
splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
|
||||
constexpr int mblockx = HEAD_DIM_V / vec_size;
|
||||
constexpr int bdy = 256 / mblockx;
|
||||
dim3 grids_merge(bsz, num_heads);
|
||||
dim3 blocks_merge(mblockx, bdy);
|
||||
merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
|
||||
in_scale,
|
||||
num_chunks,
|
||||
chunk_size,
|
||||
max_seq_len,
|
||||
num_heads,
|
||||
HEAD_DIM_V
|
||||
);
|
||||
}
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim_qk = meta_data.head_dims;
|
||||
const auto head_dim_v = meta_data.head_dims_v;
|
||||
const float rope_scale = 0.0;
|
||||
const float rope_theta = 0.0;
|
||||
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
|
||||
const uint32_t num_stage = get_cascade_attention_num_stages();
|
||||
const uint32_t num_threads = get_cascade_attention_num_threads();
|
||||
|
||||
DISPATCH_CAUSAL(causal, CAUSAL,
|
||||
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
|
||||
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
|
||||
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
|
||||
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
|
||||
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
|
||||
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
|
||||
}
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
291
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Normal file
291
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Normal file
@@ -0,0 +1,291 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "mla_cache_kernel.cuh"
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto num_tokens = meta_data.token_nums;
|
||||
auto block_size = meta_data.block_size;
|
||||
auto nope_size = meta_data.head_dims_v;
|
||||
auto all_size = meta_data.head_dims;
|
||||
int pe_size = all_size - nope_size;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;
|
||||
|
||||
constexpr int PackSize = 16 / sizeof(DataType_);
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
prefill_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
const auto& kv_nope_dims = kv_nope.dims();
|
||||
const auto& kv_pe_dims = kv_pe.dims();
|
||||
const auto& kv_cache_dims = kv_cache.dims();
|
||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto block_size = meta_data.block_size;
|
||||
auto nope_size = meta_data.head_dims_v;
|
||||
auto all_size = meta_data.head_dims;
|
||||
int pe_size = all_size - nope_size;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
constexpr int PackSize = 16 / sizeof(DataType_);
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
|
||||
|
||||
if (speculate_decoder) {
|
||||
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
} else {
|
||||
const uint32_t elem_nums = bsz * kv_num_heads * all_size;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
decode_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
const auto& kv_nope_dims = kv_nope.dims();
|
||||
const auto& kv_pe_dims = kv_pe.dims();
|
||||
const auto& kv_cache_dims = kv_cache.dims();
|
||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(prefill_mla_write_cache)
|
||||
.Inputs({"kv_nope",
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"seq_lens",
|
||||
"seq_lens_decoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int"})
|
||||
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
|
||||
|
||||
PD_BUILD_OP(decode_mla_write_cache)
|
||||
.Inputs({"kv_nope",
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"seq_lens",
|
||||
"seq_lens_encoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int",
|
||||
"speculate_decoder: bool"})
|
||||
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
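// Sketch (not part of the diff; latent_cache_index is illustrative): the flat index the
// absorb-cache kernels above compute when writing into the latent KV cache, assuming the
// [num_blocks, kv_num_heads, block_size, nope_size + pe_size] layout used for kv_cache.
#include <cstdint>

inline uint32_t latent_cache_index(uint32_t block_idx, uint32_t head_idx,
                                   uint32_t block_offset, uint32_t elem_idx,
                                   uint32_t kv_num_heads, uint32_t block_size,
                                   uint32_t nope_size, uint32_t pe_size) {
  const uint32_t all_size = nope_size + pe_size;  // e.g. 512 + 64 = 576 for MLA
  return block_idx * kv_num_heads * block_size * all_size +
         head_idx * block_size * all_size +
         block_offset * all_size +
         elem_idx;  // elem_idx >= nope_size addresses the rotary (pe) tail
}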
|
242
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh
Normal file
@@ -0,0 +1,242 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "mem_util.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void decode_absorb_cache_kernel(
|
||||
    const T* __restrict__ kv_nope,  // [bsz, kv_num_heads, nope_size] e.g. 512
|
||||
    const T* __restrict__ kv_pe,    // [bsz, kv_num_heads, pe_size] e.g. 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
                                    // nope_size + pe_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / hidden_size;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
    if (bias < nope_hidden_size) {  // nope part
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void speculate_decode_absorb_cache_kernel(
|
||||
    const T* __restrict__ kv_nope,  // [token_num, kv_num_heads, nope_size] e.g. 512
|
||||
    const T* __restrict__ kv_pe,    // [token_num, kv_num_heads, pe_size] e.g. 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
                                    // nope_size + pe_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens[ori_bi],
|
||||
token_id,
|
||||
cum_offsets[ori_bi]);
|
||||
}
|
||||
    if (bias < nope_hidden_size) {  // nope part
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void prefill_absorb_cache_kernel(
|
||||
    const T* __restrict__ kv_nope,  // [token_num, kv_num_heads, nope_size] e.g. 512
|
||||
    const T* __restrict__ kv_pe,    // [token_num, kv_num_heads, pe_size] e.g. 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
                                    // nope_size + pe_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const uint32_t token_idx = linear_index / hidden_size;
|
||||
const uint32_t bias = linear_index % hidden_size;
|
||||
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const uint32_t ori_bi = ori_token_idx / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const uint32_t ori_seq_id =
|
||||
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
|
||||
const uint32_t block_offset = ori_seq_id % block_size;
|
||||
|
||||
    if (bias < nope_hidden_size) {  // nope part
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "helper.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
@@ -25,6 +25,7 @@ struct AppendAttnMetaData {
|
||||
int kv_num_heads;
|
||||
int token_nums;
|
||||
int head_dims;
|
||||
int head_dims_v;
|
||||
int max_blocks_per_seq;
|
||||
};
|
||||
|
||||
@@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 192: { \
|
||||
constexpr size_t HEAD_DIM = 192; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("not support the head_dim: ", head_dim); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 192: { \
|
||||
constexpr size_t HEAD_DIM = 192; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 512: { \
|
||||
constexpr size_t HEAD_DIM = 512; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 576: { \
|
||||
constexpr size_t HEAD_DIM = 576; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("not support the head_dim: ", head_dim); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
|
||||
if (num_stage == 2) { \
|
||||
constexpr size_t NUM_STAGE = 2; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the num_stage: ", num_stage); \
|
||||
}
|
||||
|
||||
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
|
||||
@@ -328,10 +375,13 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
|
||||
constexpr size_t cache_bytes = 4; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the cache_type: ", cache_type); \
|
||||
}
|
||||
|
||||
|
||||
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
|
||||
if (deal_each_time == 32) { \
|
||||
constexpr size_t DEAL_EACH_TIME = 32; \
|
||||
__VA_ARGS__ \
|
||||
} else if (deal_each_time == 64) { \
|
||||
@@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
PD_THROW("not support the group_size", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 128) { \
|
||||
constexpr size_t GROUP_SIZE = 128; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
}
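// Standalone sketch (not part of the diff; DISPATCH_HEAD_DIM_SKETCH, run_kernel and main
// are illustrative): the DISPATCH_* macros above promote a runtime value to a compile-time
// template argument, which is why the kernel launchers end in long ")})})" chains of
// nested dispatches.
#include <cstdio>
#include <stdexcept>

template <int HEAD_DIM>
void run_kernel() { std::printf("instantiated for HEAD_DIM=%d\n", HEAD_DIM); }

#define DISPATCH_HEAD_DIM_SKETCH(head_dim, HEAD_DIM, ...)            \
  switch (head_dim) {                                                \
    case 512: { constexpr int HEAD_DIM = 512; __VA_ARGS__ break; }   \
    case 576: { constexpr int HEAD_DIM = 576; __VA_ARGS__ break; }   \
    default: throw std::invalid_argument("unsupported head_dim");    \
  }

int main() {
  int head_dim = 576;  // runtime value, e.g. taken from tensor shapes
  DISPATCH_HEAD_DIM_SKETCH(head_dim, HEAD_DIM, { run_kernel<HEAD_DIM>(); });
  return 0;
}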
|
||||
|
||||
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
|
||||
if (block_shape_q <= 16) { \
|
||||
constexpr size_t BLOCK_SHAPE_Q = 16; \
|
||||
|
@@ -316,6 +316,96 @@ void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
|
||||
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
|
||||
int64_t num_experts);
|
||||
void GetPositionIdsAndMaskEncoderBatch(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& position_ids,
|
||||
const paddle::Tensor& mask_encoder_batch);
|
||||
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len);
|
||||
|
||||
|
||||
void FusedRotaryPositionEncoding(
|
||||
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
paddle::Tensor& key,
|
||||
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
|
||||
// head_size]
|
||||
const paddle::Tensor& position_ids, // [num_tokens]
|
||||
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
int head_size,
|
||||
bool is_neox);
|
||||
|
||||
std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& query_bias,
|
||||
const paddle::optional<paddle::Tensor>& query_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder);
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
|
||||
@@ -370,6 +460,14 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
|
||||
paddle::Tensor const &input,
|
||||
paddle::Tensor &scales, float scale_ub);
|
||||
|
||||
std::vector<paddle::Tensor> NoauxTc(
|
||||
paddle::Tensor& scores,
|
||||
paddle::Tensor& scores_with_bias,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
float routed_scaling_factor);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
|
||||
@@ -627,6 +725,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("use_atomic_add"),
|
||||
py::arg("use_fp32_reduce"),
|
||||
py::arg("is_zp_float"));
|
||||
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
|
||||
"get_position_ids_and_mask_encoder_batch function");
|
||||
|
||||
|
||||
/**
|
||||
@@ -653,4 +753,13 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
|
||||
"dynamic_per_token_scaled_fp8_quant function",
|
||||
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
|
||||
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
|
||||
|
||||
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
|
||||
|
||||
m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function");
|
||||
|
||||
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
|
||||
|
||||
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
|
||||
}
|
||||
|
64
custom_ops/gpu_ops/env.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once

#include <cstdint>
#include <cstdlib>
#include <string>
|
||||
|
||||
inline uint32_t get_decoder_block_shape_q() {
|
||||
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
|
||||
static const uint32_t decoder_block_shape_q =
|
||||
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
|
||||
return decoder_block_shape_q;
|
||||
}
|
||||
|
||||
inline uint32_t get_encoder_block_shape_q() {
|
||||
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
|
||||
static const uint32_t encoder_block_shape_q =
|
||||
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
|
||||
return encoder_block_shape_q;
|
||||
}
|
||||
|
||||
inline uint32_t get_max_partition_size(int bsz) {
|
||||
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
|
||||
static const uint32_t max_partition_size =
|
||||
max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(max_partition_size_env));
|
||||
return max_partition_size;
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_deal_each_time() {
|
||||
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
|
||||
static const uint32_t cascade_attention_deal_each_time =
|
||||
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
|
||||
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_num_stages() {
|
||||
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
|
||||
static const uint32_t cascade_attention_num_stages =
|
||||
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
|
||||
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_num_threads() {
|
||||
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
|
||||
static const uint32_t cascade_attention_num_threads =
|
||||
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
|
||||
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
|
||||
}
|
||||
|
||||
inline bool get_mla_use_tensorcore() {
|
||||
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
|
||||
static const uint32_t mla_use_tensorcore =
|
||||
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
|
||||
return mla_use_tensorcore != 0 ? true : false;
|
||||
}
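// Usage sketch (not part of the diff; print_mla_launch_config is illustrative and assumes
// the env.h getters above are visible): the FLAGS_* names are plain environment variables,
// and each getter falls back to the default shown above (16 / 64 / 32768 / 32 / 2 / 128 /
// tensor cores enabled).
#include <cstdio>

inline void print_mla_launch_config(int bsz) {
  std::printf("dec_block_shape_q=%u enc_block_shape_q=%u max_partition_size=%u\n",
              get_decoder_block_shape_q(), get_encoder_block_shape_q(),
              get_max_partition_size(bsz));
  std::printf("deal_each_time=%u num_stages=%u num_threads=%u use_tensorcore=%d\n",
              get_cascade_attention_deal_each_time(),
              get_cascade_attention_num_stages(),
              get_cascade_attention_num_threads(),
              static_cast<int>(get_mla_use_tensorcore()));
}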
|
146
custom_ops/gpu_ops/fused_rotary_position_encoding.cu
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename T, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding_kernel(
|
||||
T* __restrict__ arr,
|
||||
const T* __restrict__ cos_ptr,
|
||||
const T* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim) {
|
||||
int x_index, y_index;
|
||||
T cos, sin;
|
||||
if (IS_NEOX) {
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = cos_ptr[x_index];
|
||||
sin = sin_ptr[x_index];
|
||||
} else {
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = cos_ptr[x_index / 2];
|
||||
sin = sin_ptr[x_index / 2];
|
||||
}
|
||||
|
||||
const T x = arr[x_index];
|
||||
const T y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
|
||||
template <typename T, bool IS_NEOX>
|
||||
__global__ void apply_rotary_embedding_kernel(
|
||||
T* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
T* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||
const int* __restrict__ position_ids, // [num_tokens]
|
||||
    const T* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim / 2]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int pos = position_ids[token_idx];
|
||||
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const T* cos_ptr = cache_ptr;
|
||||
const T* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void FusedRotaryPositionEncoding(
|
||||
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
paddle::Tensor& key,
|
||||
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
|
||||
// head_size]
|
||||
const paddle::Tensor& position_ids, // [num_tokens]
|
||||
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
int head_size,
|
||||
bool is_neox) {
|
||||
int64_t num_tokens = query.dims()[0];
|
||||
int num_heads = query.numel() / num_tokens / head_size;
|
||||
int num_kv_heads = key.numel() / num_tokens / head_size;
|
||||
int rot_dim = cos_sin_cache.dims()[1];
|
||||
int64_t query_stride = num_heads * head_size;
|
||||
int64_t key_stride = num_kv_heads * head_size;
|
||||
|
||||
if (num_tokens > 65535) {
|
||||
PD_THROW(
|
||||
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
|
||||
}
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
|
||||
query.dtype(), "apply_rotary_embedding_kernel", [&] {
|
||||
if (is_neox) {
|
||||
apply_rotary_embedding_kernel<data_t, true>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
key.data<data_t>(),
|
||||
position_ids.data<int>(),
|
||||
cos_sin_cache.data<data_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
apply_rotary_embedding_kernel<data_t, false>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
key.data<data_t>(),
|
||||
position_ids.data<int>(),
|
||||
cos_sin_cache.data<data_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
PD_BUILD_OP(fused_rotary_position_encoding)
|
||||
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
|
||||
.Outputs({"query_out", "key_out"})
|
||||
.Attrs({"head_size: int", "is_neox: bool"})
|
||||
.SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
|
||||
.SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
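// Scalar sketch (not part of the diff; rotate_pair_interleaved is illustrative): the
// per-pair rotation applied by apply_token_rotary_embedding_kernel above in the non-NEOX
// (interleaved) layout, where elements (2*i, 2*i+1) share the angle cached at index i.
inline void rotate_pair_interleaved(float* head, const float* cos_cache,
                                    const float* sin_cache, int rot_offset) {
  const float c = cos_cache[rot_offset];
  const float s = sin_cache[rot_offset];
  const float x = head[2 * rot_offset];
  const float y = head[2 * rot_offset + 1];
  head[2 * rot_offset]     = x * c - y * s;  // x' = x*cos - y*sin
  head[2 * rot_offset + 1] = y * c + x * s;  // y' = y*cos + x*sin
}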
|
@@ -24,16 +24,18 @@
|
||||
#endif
|
||||
|
||||
#define MAX_BSZ 512
|
||||
#define K 10
|
||||
#define K 20
|
||||
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
|
||||
float mtext_f[MAX_BSZ * (K + 1)]; // score
|
||||
int mtext_ranks[MAX_BSZ]; // ranks
|
||||
};
|
||||
|
||||
void GetOutputTopK(const paddle::Tensor& x,
|
||||
const paddle::Tensor& scores,
|
||||
const paddle::Tensor& ranks,
|
||||
int k,
|
||||
int64_t rank_id,
|
||||
bool wait_flag) {
|
||||
@@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x,
|
||||
|
||||
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
|
||||
float* scores_data = const_cast<float*>(scores.data<float>());
|
||||
int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
|
||||
int ret = -1;
|
||||
if (!wait_flag) {
|
||||
ret = msgrcv(msgid,
|
||||
&msg_rcv,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
|
||||
0,
|
||||
IPC_NOWAIT);
|
||||
} else {
|
||||
ret = msgrcv(msgid,
|
||||
&msg_rcv,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
@@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x,
|
||||
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
|
||||
scores_data[offset] = msg_rcv.mtext_f[offset];
|
||||
}
|
||||
ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_output_topk)
|
||||
.Inputs({"x", "scores"})
|
||||
.Inputs({"x", "scores", "ranks"})
|
||||
.Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
|
||||
.Outputs({"x_out", "scores_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}})
|
||||
.Outputs({"x_out", "scores_out", "ranks_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetOutputTopK));
|
||||
|
@@ -0,0 +1,86 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
|
||||
    const int* seq_lens_encoder,  // [bsz] encoder length of each batch entry
|
||||
    const int* seq_lens_decoder,  // [bsz] decoder length of each batch entry
|
||||
const int* seq_lens_this_time,
|
||||
    int* position_ids,  // flattened output position_ids
|
||||
int* mask_encoder_batch,
|
||||
    const int bsz) {  // batch size
|
||||
  // Current thread index (one thread handles one batch entry)
|
||||
int tid = threadIdx.x;
|
||||
if (tid >= bsz) return;
|
||||
|
||||
  // Compute the write offset of the current batch entry on the fly
|
||||
int offset = 0;
|
||||
for (int i = 0; i < tid; i++) {
|
||||
offset += seq_lens_encoder[i];
|
||||
if (seq_lens_decoder[i] > 0) {
|
||||
offset += seq_lens_this_time[i];
|
||||
}
|
||||
}
|
||||
|
||||
  // Encoder and decoder lengths of the current batch entry
|
||||
int encoder_len = seq_lens_encoder[tid];
|
||||
int decoder_len = seq_lens_decoder[tid];
|
||||
int seq_len_this_time = seq_lens_this_time[tid];
|
||||
|
||||
  // Write position_ids for the encoder (prefill) tokens
|
||||
for (int i = 0; i < encoder_len; i++) {
|
||||
position_ids[offset + i] = i;
|
||||
mask_encoder_batch[offset + i] = 1;
|
||||
}
|
||||
offset += encoder_len;
|
||||
|
||||
  // Write position_ids for the decoder tokens
|
||||
if (decoder_len > 0) {
|
||||
for (int i = 0; i < seq_len_this_time; i++) {
|
||||
      position_ids[offset + i] = decoder_len + i;  // continue from the cached decoder length
|
||||
mask_encoder_batch[offset + i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void GetPositionIdsAndMaskEncoderBatch(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& position_ids,
|
||||
const paddle::Tensor& mask_encoder_batch) {
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
const_cast<int*>(position_ids.data<int>()),
|
||||
const_cast<int*>(mask_encoder_batch.data<int>()),
|
||||
bsz);
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_position_ids_and_mask_encoder_batch)
|
||||
.Inputs({"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"position_ids",
|
||||
"mask_encoder_batch"})
|
||||
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
|
||||
.SetInplaceMap({{"position_ids", "position_ids_out"},
|
||||
{"mask_encoder_batch", "mask_encoder_batch_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
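// Hypothetical CPU reference (not part of the diff) mirroring the kernel above for a
// two-entry batch: entry 0 is prefilling 3 tokens, entry 1 decodes 1 token with 5 tokens
// already cached. Expected output: position_ids = [0, 1, 2, 5], mask_encoder_batch = [1, 1, 1, 0].
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> seq_lens_encoder = {3, 0};
  std::vector<int> seq_lens_decoder = {0, 5};
  std::vector<int> seq_lens_this_time = {3, 1};
  std::vector<int> position_ids, mask;
  for (size_t b = 0; b < seq_lens_encoder.size(); ++b) {
    for (int i = 0; i < seq_lens_encoder[b]; ++i) {  // prefill: positions 0..len-1
      position_ids.push_back(i);
      mask.push_back(1);
    }
    if (seq_lens_decoder[b] > 0) {  // decode: continue from the cached length
      for (int i = 0; i < seq_lens_this_time[b]; ++i) {
        position_ids.push_back(seq_lens_decoder[b] + i);
        mask.push_back(0);
      }
    }
  }
  for (size_t i = 0; i < position_ids.size(); ++i) {
    std::printf("%d/%d ", position_ids[i], mask[i]);
  }
  std::printf("\n");
  return 0;
}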
|
@@ -39,10 +39,12 @@ namespace cub = hipcub;
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "env.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#include "paddle/phi/core/cuda_stream.h"
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
#include "paddle/phi/backends/gpu/gpu_info.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
@@ -513,3 +515,10 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
inline int GetSMVersion() {
|
||||
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
|
||||
phi::backends::gpu::GetCurrentDeviceId());
|
||||
return sm_version;
|
||||
|
||||
}
|
||||
|
255
custom_ops/gpu_ops/mla_attn/attention_updater.cuh
Normal file
@@ -0,0 +1,255 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/detail/helper_macros.hpp>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
struct MaxOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
|
||||
};
|
||||
|
||||
template <int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template <typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator& op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Allreduce<2> {
|
||||
template <typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator& op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& summary, Operator& op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0>& dst,
|
||||
Tensor<Engine1, Layout1>& src, Operator& op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++) {
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& summary, Operator& op) {
|
||||
thread_reduce_<init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& max) {
|
||||
MaxOp<float> max_op;
|
||||
reduce_<init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template <bool init, bool warp_reduce = true, typename Engine0, typename Layout0, typename Engine1,
|
||||
typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& sum) {
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) {
|
||||
quad_allreduce_(sum, sum, sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void apply_exp2(Tensor<Engine0, Layout0>& tensor,
|
||||
Tensor<Engine1, Layout1> const& max) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
auto row_max = max(mi);
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = __expf(tensor(mi, ni) - row_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor,
|
||||
Tensor<Engine1, Layout1> const& max,
|
||||
const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
auto row_max = max(mi);
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// row_max * scale is a constant for each row, so we can use fma here
|
||||
tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE>
|
||||
struct OnlineSoftmax {
|
||||
constexpr static float fill_value = -5e4;
|
||||
  using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
  TensorT row_max, row_sum, scores_scale;
  float sm_scale_log2;

  CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) {
    clear(scores_scale);
  };

  __forceinline__ __device__ TensorT get_lse() const { return row_sum; }

  template <bool init, typename Tensor0>
  __forceinline__ __device__ TensorT update(Tensor0& acc_s) {
    // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));

    static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
    if constexpr (init) {
      reduce_max</*init=*/true>(scores, row_max);
      if constexpr (WITH_SCALE) {
        scale_apply_exp2(scores, row_max, sm_scale_log2);
      } else {
        apply_exp2(scores, row_max);
      }
      reduce_sum</*init=*/true, /*warp_reduce=*/false>(scores, row_sum);
    } else {
      // update row_max
      Tensor scores_max_prev = make_fragment_like(row_max);
      cute::copy(row_max, scores_max_prev);
      reduce_max</*init=*/false>(scores, row_max);
      // update scores_scale and scale row_sum
      #pragma unroll
      for (int mi = 0; mi < size(row_max); ++mi) {
        float scores_max_cur = row_max(mi);
        if constexpr (WITH_SCALE) {
          scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
        } else {
          scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur);
        }
        row_sum(mi) *= scores_scale(mi);
      }
      // perform exp2 on scores
      if constexpr (WITH_SCALE) {
        scale_apply_exp2(scores, row_max, sm_scale_log2);
      } else {
        apply_exp2(scores, row_max);
      }
      // update row_sum
      reduce_sum</*init=*/false, /*warp_reduce=*/false>(scores, row_sum);
    }
    return scores_scale;
  };

  template <typename Tensor0>
  __forceinline__ __device__ TensorT finalize(Tensor0& acc_s) {
    // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
    static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
    #pragma unroll
    for (int mi = 0; mi < size(row_max); ++mi) {
      float sum = row_sum(mi);
      float inv_sum = 1.f / sum;
      scores_scale(mi) = inv_sum;
      row_max(mi) *= sm_scale_log2;
    }
    return scores_scale;
  };

  template <typename Tensor1>
  __forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
    #pragma unroll
    for (int mi = 0; mi < size(row_max); ++mi) {
      #pragma unroll
      for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
        acc_o_rowcol(mi, ni) *= scores_scale(mi);
      }
    }
  };

  template <typename Tensor1, typename Tensor2>
  __forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) {
    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
    #pragma unroll
    for (int mi = 0; mi < size(row_max); ++mi) {
      #pragma unroll
      for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
        acc_o_rowcol(mi, ni) *= scores_scale_input(mi);
      }
    }
  };
};

} // namespace mla_attn
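The update()/finalize()/rescale_o() trio above implements a standard online softmax: each new KV tile may raise the running row maximum, and the previously accumulated row sum (and the caller's output accumulator) must then be rescaled by exp(m_old - m_new) before the new tile's contributions are added. A minimal scalar sketch of that recurrence, independent of the CUTE tensor machinery; the OnlineSoftmaxRow name and the std::vector interface are illustrative only, not part of this diff:

#include <cmath>
#include <vector>

struct OnlineSoftmaxRow {
  float m = -INFINITY;  // running row maximum
  float d = 0.f;        // running sum of exp(score - m)

  // Fold one tile of scores into the running statistics.
  // Returns the factor the previously accumulated output row must be multiplied by.
  float update(const std::vector<float>& tile_scores) {
    float m_new = m;
    for (float s : tile_scores) m_new = std::max(m_new, s);
    const float scale = (m == -INFINITY) ? 0.f : std::exp(m - m_new);
    d *= scale;                                    // rescale the old sum to the new maximum
    for (float s : tile_scores) d += std::exp(s - m_new);
    m = m_new;
    return scale;                                  // caller rescales its O accumulator by this
  }

  // Final per-row normalisation factor, the analogue of finalize() above.
  float inv_sum() const { return 1.f / d; }
};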
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu (new file, 235 lines)
@@ -0,0 +1,235 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include <type_traits>
#include <vector>
#include "cute/tensor.hpp"
#include "mla_hopper.cuh"
#include <iostream>
#include <string>
#include <sstream>

#include "batch_mla_with_paged_kv_cache.h"
#include "env.h"

using namespace cute;
using namespace mla_attn;
using namespace std;

template <typename T>
struct cascade_type_traits {
  using type = T;
  using cutlass_type = T;
};
template <>
struct cascade_type_traits<phi::dtype::bfloat16> {
  using type = __nv_bfloat16;
  using cutlass_type = cutlass::bfloat16_t;
};
template <>
struct cascade_type_traits<phi::dtype::float16> {
  using type = half;
  using cutlass_type = cutlass::half_t;
};
template <>
struct cascade_type_traits<phi::dtype::float8_e4m3fn> {
  using type = __nv_fp8_e4m3;
  using cutlass_type = cutlass::float_e4m3_t;
};

template <typename T>
void BatchMLAWithPagedKVCacheKernel(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
    const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
    const paddle::Tensor& num_blocks_x_device,
    const std::string& cache_quant_type_str,
    const int num_blocks_x,
    const int max_seq_len,
    const int max_dec_len,
    const float softmax_scale,
    const float quant_max_bound,
    const float quant_min_bound,
    const float in_scale,
    const int draft_token_num,
    const bool causal,
    cudaStream_t& stream,
    paddle::Tensor* out) {
  using NV_TYPE = typename cascade_type_traits<T>::type;
  using CUTLASS_TYPE = typename cascade_type_traits<T>::cutlass_type;
  const auto token_num = meta_data.token_nums;
  const auto block_size = meta_data.block_size;
  const auto bsz = meta_data.batch_size;
  const auto q_head_num = meta_data.q_num_heads;
  const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
  const auto max_block_num = bsz * max_block_num_per_seq;
  const uint32_t chunk_size = get_max_partition_size(bsz);

  int q_head_dim = meta_data.head_dims;
  int k_head_dim = meta_data.head_dims;
  int v_head_dim = meta_data.head_dims_v;
  // int num_chunks = max_dec_len / chunk_size;
  int num_chunks = div_up(max_dec_len, chunk_size);

  auto *allocator = paddle::GetAllocator(q.place());
  phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
  O_tmp = allocator->Allocate(
      phi::SizeOf(q.dtype()) *
      static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num * v_head_dim));
  m_tmp = allocator->Allocate(
      sizeof(float) *
      static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
  d_tmp = allocator->Allocate(
      sizeof(float) *
      static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));

  Params<CUTLASS_TYPE, CUTLASS_TYPE, CUTLASS_TYPE, int> params = {};
  params.Q = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(q.data<T>()));
  params.KV = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(latent_cache.data<T>()));
  params.O = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(out->data<T>()));
  params.O_tmp = reinterpret_cast<CUTLASS_TYPE*>(O_tmp->ptr());
  params.m = reinterpret_cast<float*>(m_tmp->ptr());
  params.d = reinterpret_cast<float*>(d_tmp->ptr());
  params.block_tables = const_cast<int*>(block_tables.data<int>());
  params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
  params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
  params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
  params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
  params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
  params.batch_ids = const_cast<int*>(batch_ids.data<int>());
  params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
  params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
  params.num_blocks_x_int = num_blocks_x;
  params.q_stride_bsz = q_head_num * q_head_dim;
  params.q_stride_head_num = q_head_dim;
  params.kv_stride_block_num = block_size * k_head_dim;
  params.kv_stride_block_size = k_head_dim;
  params.o_stride_bsz = q_head_num * v_head_dim;
  params.o_stride_head_num = v_head_dim;
  params.bsz = bsz;
  params.token_num = token_num;
  params.max_seq_len = max_seq_len;
  params.max_block_num = max_block_num;
  params.max_block_num_per_seq = max_block_num_per_seq;
  params.q_num_head = q_head_num;
  params.qk_head_dim = q_head_dim;
  params.vo_head_dim = v_head_dim;
  params.block_size = block_size;
  params.max_draft_token_num = draft_token_num;
  params.sm_scale = softmax_scale;
  params.chunk_size = chunk_size;
  params.chunk_num = num_chunks;

  if (q_head_dim == 576) {
    BatchMLAWithPagedKVCacheDispatched<576, 512, NV_TYPE>(
        params, stream
    );
  } else {
    PD_THROW("error!!! q_head_dim must be 576 !!!\n");
  }
}

template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
    const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
    const paddle::Tensor& num_blocks_x_device,
    const std::string& cache_quant_type_str,
    const int num_blocks_x,
    const int max_seq_len,
    const int max_dec_len,
    const float softmax_scale,
    const float quant_max_bound,
    const float quant_min_bound,
    const float in_scale,
    const int draft_token_num,
    const bool causal,
    cudaStream_t& stream,
    paddle::Tensor* out);

template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
    const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
    const paddle::Tensor& num_blocks_x_device,
    const std::string& cache_quant_type_str,
    const int num_blocks_x,
    const int max_seq_len,
    const int max_dec_len,
    const float softmax_scale,
    const float quant_max_bound,
    const float quant_min_bound,
    const float in_scale,
    const int draft_token_num,
    const bool causal,
    cudaStream_t& stream,
    paddle::Tensor* out);
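The launcher above splits a long decode KV range into chunks of chunk_size and reserves per-chunk partial outputs (O_tmp) plus per-row running max/sum buffers (m_tmp, d_tmp) that a later merge pass combines. A rough sketch of the element counts implied by those Allocate() calls; the helper name and struct are illustrative only, not part of this diff:

#include <cstddef>

struct MlaWorkspaceElems {
  size_t o_tmp;  // partial output elements, one tile per chunk (same dtype as q)
  size_t m_tmp;  // per-row running max (float)
  size_t d_tmp;  // per-row running sum (float)
};

inline MlaWorkspaceElems mla_workspace_elems(int max_dec_len, int chunk_size,
                                             int bsz, int draft_token_num,
                                             int q_head_num, int v_head_dim) {
  const int num_chunks = (max_dec_len + chunk_size - 1) / chunk_size;  // div_up
  const size_t rows =
      static_cast<size_t>(num_chunks) * bsz * draft_token_num * q_head_num;
  return {rows * static_cast<size_t>(v_head_dim),  // O_tmp
          rows,                                    // m_tmp
          rows};                                   // d_tmp
}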
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h (new file, 69 lines)
@@ -0,0 +1,69 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include "paddle/extension.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/allocator.h"
#include "append_attn/utils.cuh"

template <typename T>
void BatchMLAWithPagedKVCacheKernel(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
    const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
    const paddle::Tensor& num_blocks_x_device,
    const std::string& cache_quant_type_str,
    const int num_blocks_x,
    const int max_seq_len,
    const int max_dec_len,
    const float softmax_scale,
    const float quant_max_bound,
    const float quant_min_bound,
    const float in_scale,
    const int draft_token_num,
    const bool causal,
    cudaStream_t& stream,
    paddle::Tensor* out);
custom_ops/gpu_ops/mla_attn/epilogue.cuh (new file, 175 lines)
@@ -0,0 +1,175 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */

#ifndef ATTENTION_HOPPER_EPILOGUE_CUH_
#define ATTENTION_HOPPER_EPILOGUE_CUH_

#include <cutlass/cutlass.h>

#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"

#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA

namespace mla_attn {

using namespace cute;

template <typename Ktraits>
struct CollectiveEpilogue {
  using DTypeO = typename Ktraits::DTypeO;
  static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
  static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
  static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
  using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;

  static constexpr int NUM_WARPS = Ktraits::NUM_WARPS;
  static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;

  static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
  static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;

  using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
      decltype(cute::get<1>(TileShape_PDV{}))>());
  using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));

  using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
  using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;

  using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
  using StrideT = cute::Shape<int32_t, _1, int32_t>;
  using LayoutT = cute::Layout<ShapeT, StrideT>;

  using ShapeTmpT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
  using StrideTmpT = cute::Shape<int32_t, _1, int32_t, int32_t>;
  using LayoutTmpT = cute::Layout<ShapeTmpT, StrideTmpT>;

  using ShapeNTMAT = cute::Shape<int32_t, int32_t>;
  using StrideNTMAT = cute::Shape<int32_t, _1>;
  using LayoutNTMAT = cute::Layout<ShapeNTMAT, StrideNTMAT>;

  using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
  using TMA_O = decltype(make_tma_copy(
      GmemTiledCopyOTMA{},
      make_tensor(make_gmem_ptr(static_cast<DTypeO*>(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{},
      select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O

  static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v<DTypeO>); // 8
  static_assert(HEAD_DIM_VO % VEC_SIZE == 0);
  static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64
  static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0);
  static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW;
  using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, DTypeO>;
  using TiledCopyOThrLayout = decltype(cute::make_layout(
      cute::make_shape(Int<NUM_ROWS>{}, Int<NUM_THREADS_PER_ROW>{}), LayoutRight{}));
  using TiledCopyOValLayout =
      decltype(cute::make_layout(cute::make_shape(_1{}, Int<VEC_SIZE>{}), LayoutRight{}));
  using TiledCopyO =
      decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout
                               TiledCopyOValLayout{} // Val layout
                               ));
  struct Arguments {
    DTypeO* O_ptr;
    LayoutNTMAT const layout_O;
    DTypeO* O_ptr_tmp;
    LayoutNTMAT const layout_O_tmp;
  };

  // Device side kernel params
  struct Params {
    DTypeO* O_ptr;
    LayoutNTMAT const layout_O;
    DTypeO* O_ptr_tmp;
    LayoutNTMAT const layout_O_tmp;
  };

  static Params to_underlying_arguments_ntma(Arguments const& args) {
    return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp};
  }

  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& epilogue_params) {}

  template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE,
            typename TiledMma>
  CUTLASS_DEVICE void store(Params const& epilogue_params,
                            FrgTensorO const& tOrO,
                            FrgTensorLSE const& lse,
                            SharedStorage& shared_storage,
                            TiledMma tiled_mma,
                            const int thread_idx,
                            const int bid,
                            const int bsz,
                            const int seq_len_now,
                            const int start_token_idx,
                            const int tile_idx,
                            const int kv_len,
                            const int chunk_size,
                            const int max_draft_token_num,
                            const int o_stride_bsz) {
    const int num_chunks = cute::ceil_div(kv_len, chunk_size);
    Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
    auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);

    Tensor tOrO_out = convert_type<DTypeO>(tOrO);
    Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
    // make sure gemm done
    cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
                                      /*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
    // r2s
    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
    // make sure r2s done
    cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
                                      /*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
    TiledCopyO gmem_tiled_copy_O;
    auto O_ptr = num_chunks == 1
                     ? epilogue_params.O_ptr + start_token_idx * o_stride_bsz
                     : epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
    Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O);
    Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{});
    Tensor cO = make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx)
    ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx);
    Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D)
    Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D)
    Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D)
    Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D))
    Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D))
    Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D))

    // copy if not out of bound
    auto predicate_fn = [&](auto coords) {
      auto s_coords = tOcOGroup(_0{}, coords);
      return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now);
    };
    copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
  }
};

} // namespace mla_attn

#endif  // ATTENTION_HOPPER_EPILOGUE_CUH_
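The copy_if predicate in CollectiveEpilogue::store maps each output row to a token index by integer division by GROUP_SIZE (the tile interleaves GROUP_SIZE query heads per token) and skips rows whose token lies beyond seq_len_now. A scalar sketch of that guard; the buffer names and the plain nested loops are placeholders for the tiled, vectorized copy above:

// Scalar sketch of the predicated output store (illustrative only).
void store_o_tile(const float* smem_o, float* gmem_o,
                  int rows_in_tile, int head_dim,
                  int group_size, int seq_len_now) {
  for (int row = 0; row < rows_in_tile; ++row) {
    const int token_idx = row / group_size;  // same test as predicate_fn above
    if (token_idx >= seq_len_now) continue;  // rows past the valid query length are skipped
    for (int d = 0; d < head_dim; ++d) {
      gmem_o[row * head_dim + d] = smem_o[row * head_dim + d];
    }
  }
}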
custom_ops/gpu_ops/mla_attn/kernel_traits.cuh (new file, 163 lines)
@@ -0,0 +1,163 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */

#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_

#include <type_traits>

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

namespace mla_attn {

using namespace cute;

template <typename MainloopPipeline, typename MainloopPipelineQ, class DTypeQ, class DTypeKV, class DTypeQKAccum, class DTypeOut, class IdType,
          int BLOCK_SHAPE_KV, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutP, class SmemLayoutRow, class SmemLayoutO>
struct alignas(16) SharedStorageQKVO {
  alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutQ>> smem_q;
  alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutP>> smem_p;
  alignas(16) cute::array_aligned<DTypeQKAccum, cute::cosize_v<SmemLayoutRow>> smem_scale;
  union {
    alignas(16) cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutK>> smem_kv;
    alignas(16) cute::array_aligned<DTypeOut, cute::cosize_v<SmemLayoutO>> smem_o;
  };
  struct {
    alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q;
    alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv;
  };
};

template <bool USE_TMA_LOAD_KV_, int HEAD_DIM_QK_, int HEAD_DIM_VO_, int GROUP_SIZE_, int BLOCK_SHAPE_Q_, int BLOCK_SHAPE_KV_,
          int NUM_STAGES_, typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_, typename NV_TYPE_>
struct AttentionKernelTraits {

  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;
  using DTypeQKAccum = float;
  using DTypePVAccum = float;
  using NV_TYPE = NV_TYPE_;

  static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_;
  static constexpr int GROUP_SIZE = GROUP_SIZE_;
  static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_;
  static_assert(BLOCK_SHAPE_Q % 64 == 0);
  static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_;
  static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
  static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
  static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK;
  static_assert(HEAD_DIM_QK % 32 == 0);
  static_assert(HEAD_DIM_VO % 32 == 0);

  static constexpr int NUM_WARPS = 12;
  static constexpr int NUM_THREADS = 384;
  static constexpr int NUM_PRODUCER_THREADS = 128;

  using TileShape_QKD = Shape<Int<BLOCK_SHAPE_Q>, Int<BLOCK_SHAPE_KV>, Int<HEAD_DIM_QK>>;
  using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;

  static constexpr int NUM_STAGES = NUM_STAGES_;

  using AtomLayoutQKD = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _1, _1>>;
  using AtomLayoutPV = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _2, _1>>;
  using TiledMmaQK = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<DTypeQ, DTypeKV, DTypeQKAccum, TileShape_QKD>(), AtomLayoutQKD{}));
  using TiledMmaPV = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
                                 GMMA::Major::K, GMMA::Major::MN>(),
      AtomLayoutPV{}));
  using TiledMmaPVSS = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
                                 GMMA::Major::K, GMMA::Major::MN>(),
      AtomLayoutPV{}));

  static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{});

  using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
      decltype(cute::get<2>(TileShape_QKD{}))>());
  using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));

  using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
      decltype(cute::get<2>(TileShape_QKD{}))>());
  using SmemLayoutK = decltype(tile_to_shape(
      SmemLayoutAtomK{},
      make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int<NUM_STAGES>{})));
  using SmemLayoutVt = decltype(composition(
      SmemLayoutK{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}),
                                                    get<1>(TileShape_QKD{}), Int<NUM_STAGES>{}),
                                         Step<_2, _1, _3>{})));
  using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
      decltype(cute::get<1>(TileShape_PDV{}))>());
  using SmemLayoutV = decltype(tile_to_shape(
      SmemLayoutAtomV{},
      make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{})));

  // Note this is the transpose in terms of the view, not in terms of memory.
  using SmemLayoutVtOneStage = decltype(composition(
      SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}),
                                                    get<2>(TileShape_PDV{}), Int<1>{}),
                                         Step<_2, _1, _3>{})));

  using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
      decltype(cute::get<1>(TileShape_PDV{}))>());
  using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));

  using SmemCopyAtom = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeQ>;

  static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32);
  using SmemLayoutRowOneStage = Layout<Shape<_2, Int<128>>, Stride<_1, _2>>;
  using SmemLayoutRowTwoStage = Layout<Shape<_2, Int<128>, _2>, Stride<_1, _2, _256>>;
  using SmemLayoutRow = std::conditional_t<IS_CTA_32, SmemLayoutRowTwoStage, SmemLayoutRowOneStage>;

  using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
      decltype(cute::get<1>(TileShape_QKD{}))>());
  using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{})));
  using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int<BLOCK_SHAPE_Q>{}, Int<BLOCK_SHAPE_KV>{}, Int<2>{})));
  using SmemLayoutP = std::conditional_t<IS_CTA_32, SmemLayoutPSSTwoStage, SmemLayoutPSSOneStage>;

  using MainloopPipelineQ = typename cutlass::PipelineAsync<1>;
  using PipelineStateQ = typename cutlass::PipelineState<1>;
  using MainloopPipeline =
      std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
                         typename cutlass::PipelineAsync<NUM_STAGES>>;
  using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;

  using SharedStorage = SharedStorageQKVO<MainloopPipeline, MainloopPipelineQ, DTypeQ, DTypeKV, DTypeQKAccum, DTypeO, IdType, BLOCK_SHAPE_KV,
                                          SmemLayoutQ, SmemLayoutK, SmemLayoutP, SmemLayoutRow, SmemLayoutO>;
};

} // namespace mla_attn

#endif
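AttentionKernelTraits gives the Q tile, the P tile, and the per-row scales their own shared-memory arrays, while the multi-stage KV buffer and the O tile overlap through the union in SharedStorageQKVO, so only the larger of the two contributes to the budget. A back-of-the-envelope byte estimate under those shapes, ignoring swizzle and alignment padding and assuming 2-byte Q/KV/O element types (a sketch, not the exact cosize computation used above):

#include <cstddef>

inline size_t smem_bytes_estimate(int block_q, int block_kv, int head_dim_qk,
                                  int head_dim_vo, int num_stages) {
  const size_t elem = 2;  // bf16 / fp16
  const size_t q  = static_cast<size_t>(block_q) * head_dim_qk * elem;   // smem_q
  const size_t p  = static_cast<size_t>(block_q) * block_kv * elem;      // smem_p
  const size_t kv = static_cast<size_t>(num_stages) * block_kv * head_dim_qk * elem;
  const size_t o  = static_cast<size_t>(block_q) * head_dim_vo * elem;
  const size_t scale = 2 * 128 * sizeof(float);  // one-stage SmemLayoutRow: 2 x 128 floats
  // KV stages and the O tile share storage through the union.
  return q + p + scale + (kv > o ? kv : o);
}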
custom_ops/gpu_ops/mla_attn/mainloop_load.cuh (new file, 348 lines)
@@ -0,0 +1,348 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
#define ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"

#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA

namespace mla_attn {

using namespace cute;

template <typename Ktraits, bool CAUSAL>
struct CollectiveMainloop {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using DTypeMD = float;
  using IdType = typename Ktraits::IdType;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  using TileShape_PDV = typename Ktraits::TileShape_PDV;
  static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
  static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});

  static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
  static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
  static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;

  using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(DTypeQ); // 8
  static_assert(HEAD_DIM_QK % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // 576 512
  static constexpr int kGmemThreadsPerRow = 64 / kGmemElemsPerLoad; // 8
  using AlignmentTypeQ = cute::uint_byte_t<static_cast<int>(sizeof(DTypeQ)) * kGmemElemsPerLoad>;
  using GmemCopyAtomQ = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<AlignmentTypeQ>, DTypeQ>;
  static constexpr int kNThreadsLoad = Ktraits::NUM_PRODUCER_THREADS;
  static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
  using GmemLayoutAtom = Layout<
      Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // 32, 8
      Stride<Int<kGmemThreadsPerRow>, _1>>;
  using GmemTiledCopy = decltype(make_tiled_copy(
      GmemCopyAtomQ{},
      GmemLayoutAtom{},
      Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read

  using GmemLayoutAtomQ = Layout<
      Shape<Int<Ktraits::NUM_PRODUCER_THREADS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // 32, 8
      Stride<Int<kGmemThreadsPerRow>, _1>>;
  using GmemTiledCopyQ = decltype(make_tiled_copy(
      GmemCopyAtomQ{},
      GmemLayoutAtomQ{},
      Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read

  using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
  using SmemLayoutAtomQ = typename Ktraits::SmemLayoutAtomQ;

  using SmemLayoutK = typename Ktraits::SmemLayoutK;
  using SmemLayoutV = typename Ktraits::SmemLayoutV;
  using SmemLayoutVt = typename Ktraits::SmemLayoutVt;

  using ShapeQT = cute::Shape<int32_t, int32_t>;
  using StrideQT = cute::Shape<int32_t, _1>;
  using LayoutQT = cute::Layout<ShapeQT, StrideQT>;

  using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
  using StrideT = cute::Shape<int32_t, _1, int32_t>;
  using LayoutT = cute::Layout<ShapeT, StrideT>;

  using ShapeMDT = cute::Shape<int32_t, int32_t>;
  using StrideMDT = cute::Shape<int32_t, _1>;
  using LayoutMDT = cute::Layout<ShapeMDT, StrideMDT>;

  using TMA_KV = decltype(make_tma_copy(
      GmemTiledCopyKV{},
      make_tensor(
          make_gmem_ptr(static_cast<DTypeKV const*>(nullptr)),
          repeat_like(StrideT{}, int32_t(0)), StrideT{}
      ),
      take<0, 2>(SmemLayoutK{}),
      select<1, 2>(TileShape_QKD{}),
      _1{})); // no mcast for KV

  static constexpr bool USE_TMA_LOAD_KV = Ktraits::USE_TMA_LOAD_KV;
  using MainloopPipeline = typename Ktraits::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ;
  using PipelineParamsQ = typename MainloopPipelineQ::Params;
  using PipelineStateQ = typename MainloopPipelineQ::PipelineState;

  static constexpr uint32_t TmaTransactionBytesQ =
      static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);
  static constexpr uint32_t TmaTransactionBytesKV =
      static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<DTypeKV> / 8);

  // Host side kernel arguments
  struct Arguments {
    LayoutQT layout_Q;
    LayoutT layout_KV;
    LayoutMDT layout_MD;
    DTypeQ const* Q_ptr;
    DTypeKV const* KV_ptr;
    DTypeMD const* m_ptr;
    DTypeMD const* d_ptr;
    IdType const* kv_block_tables;
    IdType const* seq_lens_this_time;
    IdType const* seq_lens_encoder;
    IdType const* seq_lens_decoder;
    IdType const* cumsum_q_seqlens;
    IdType const* batch_ids;
    IdType const* tile_ids_per_batch;
    IdType const* num_blocks_x;
    float sm_scale;
    int bsz;
    int max_block_num;
    int max_block_num_per_seq;
    int q_stride_bsz;
    int q_stride_head_num;
    int kv_stride_block_num;
    int kv_stride_block_size;
    int o_stride_bsz;
    int o_stride_head_num;
    int chunk_size;
    int chunk_num;
    int max_draft_token_num;
  };

  // Device side kernel params
  struct Params {
    LayoutQT layout_Q;
    LayoutT layout_KV;
    LayoutMDT layout_MD;
    DTypeQ *Q_ptr;
    DTypeKV* KV_ptr;
    DTypeMD* m_ptr;
    DTypeMD* d_ptr;
    IdType* kv_block_tables;
    IdType* seq_lens_this_time;
    IdType* seq_lens_encoder;
    IdType* seq_lens_decoder;
    IdType* cumsum_q_seqlens;
    IdType* batch_ids;
    IdType* tile_ids_per_batch;
    IdType* num_blocks_x;
    float sm_scale;
    int bsz;
    int max_block_num;
    int max_block_num_per_seq;
    int q_stride_bsz;
    int q_stride_head_num;
    int kv_stride_block_num;
    int kv_stride_block_size;
    int o_stride_bsz;
    int o_stride_head_num;
    int chunk_size;
    int chunk_num;
    int max_draft_token_num;
    TMA_KV tma_load_KV;
  };

  static Params to_underlying_arguments(Arguments const& args) {
    TMA_KV tma_load_KV;
    if constexpr (USE_TMA_LOAD_KV) {
      Tensor mKV = make_tensor(make_gmem_ptr(args.KV_ptr), args.layout_KV);
      tma_load_KV =
          make_tma_copy(GmemTiledCopyKV{}, mKV, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{});
    }
    return {args.layout_Q,
            args.layout_KV,
            args.layout_MD,
            const_cast<DTypeQ*>(args.Q_ptr),
            const_cast<DTypeKV*>(args.KV_ptr),
            const_cast<DTypeMD*>(args.m_ptr),
            const_cast<DTypeMD*>(args.d_ptr),
            const_cast<IdType*>(args.kv_block_tables),
            const_cast<IdType*>(args.seq_lens_this_time),
            const_cast<IdType*>(args.seq_lens_encoder),
            const_cast<IdType*>(args.seq_lens_decoder),
            const_cast<IdType*>(args.cumsum_q_seqlens),
            const_cast<IdType*>(args.batch_ids),
            const_cast<IdType*>(args.tile_ids_per_batch),
            const_cast<IdType*>(args.num_blocks_x),
            args.sm_scale,
            args.bsz,
            args.max_block_num,
            args.max_block_num_per_seq,
            args.q_stride_bsz,
            args.q_stride_head_num,
            args.kv_stride_block_num,
            args.kv_stride_block_size,
            args.o_stride_bsz,
            args.o_stride_head_num,
            args.chunk_size,
            args.chunk_num,
            args.max_draft_token_num,
            tma_load_KV
    };
  }

  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& mainloop_params) {
    if constexpr (USE_TMA_LOAD_KV) {
      cute::prefetch_tma_descriptor(mainloop_params.tma_load_KV.get_tma_descriptor());
    }
  }

  template <typename SharedStorage>
  CUTLASS_DEVICE void load_q(Params const& mainloop_params,
                             MainloopPipelineQ pipeline_q,
                             PipelineStateQ& smem_pipe_write_q,
                             SharedStorage& shared_storage,
                             const int thread_idx,
                             const int bid) {
    int start_q_token_idx = mainloop_params.cumsum_q_seqlens[bid];
    int offset_Q = mainloop_params.q_stride_bsz * start_q_token_idx;
    Tensor mQ = make_tensor(make_gmem_ptr(mainloop_params.Q_ptr + offset_Q), mainloop_params.layout_Q);
    Tensor gQ =
        local_tile(mQ, select<0, 2>(TileShape_QKD{}), make_coord(_, _0{}))(_, _, _0{});
    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
    Tensor cQ = cute::make_identity_tensor(gQ.shape());

    GmemTiledCopyQ gmem_tiled_copy_q;
    auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(thread_idx);
    Tensor tQgQ = gmem_thr_copy_q.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_q.partition_D(sQ);
    Tensor tQcQ = gmem_thr_copy_q.partition_D(cQ);
    Tensor tQcQGroup = flatten_1(tQcQ);

    int valid_q_size = mainloop_params.seq_lens_this_time[bid];
    auto q_predicate_fn = [&](auto coords) {
      auto s_coords = tQcQGroup(_0{}, coords);
      return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, valid_q_size);
    };
    Tensor tQgQiGroup = flatten_1(tQgQ);
    Tensor tQsQiGroup = flatten_1(tQsQ);

    pipeline_q.producer_acquire(smem_pipe_write_q);
    copy_if(gmem_tiled_copy_q, q_predicate_fn, tQgQiGroup, tQsQiGroup);
    pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
    ++smem_pipe_write_q;
  }

  template <typename SharedStorage>
  CUTLASS_DEVICE void load_kv(Params const& mainloop_params,
                              MainloopPipeline pipeline_kv,
                              PipelineState& smem_pipe_write_kv,
                              SharedStorage& shared_storage,
                              const int bid,
                              const int kv_len,
                              const int tile_idx) {
    int thread_idx = threadIdx.x;
    int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
    Tensor mKV = make_tensor(make_gmem_ptr(mainloop_params.KV_ptr), mainloop_params.layout_KV);
    Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
    GmemTiledCopy gmem_tiled_copy_kv;
    auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);

    static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
    const int start_len = tile_idx * mainloop_params.chunk_size;
    const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
    const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;

    auto kv_block_tables = make_tensor(
        make_gmem_ptr(mainloop_params.kv_block_tables),
        make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq),
                    make_stride(mainloop_params.max_block_num_per_seq, 1)));

    Tensor tKgK = gmem_thr_copy_kv.partition_S(gKV);
    Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);

    for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
      const int block_idx = kv_block_tables(bid, kv_tile_idx);
      pipeline_kv.producer_acquire(smem_pipe_write_kv);
      Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, block_idx));
      Tensor tKsKiGroup =
          flatten_1(tKsK(_, _, _, smem_pipe_write_kv.index()));
      copy(gmem_tiled_copy_kv, tKgKiGroup, tKsKiGroup);
      pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
      ++smem_pipe_write_kv;
    }
  }

  template <typename SharedStorage>
  CUTLASS_DEVICE void load_kv_tma(Params const& mainloop_params,
                                  MainloopPipeline pipeline_kv,
                                  PipelineState& smem_pipe_write_kv,
                                  SharedStorage& shared_storage,
                                  const int bid,
                                  const int kv_len,
                                  const int tile_idx) {
    int thread_idx = threadIdx.x;
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});

    Tensor mKV = mainloop_params.tma_load_KV.get_tma_tensor(mainloop_params.layout_KV.shape());

    // Prepare the TMA loads
    Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
    auto [tKgK, tKsK] =
        tma_partition(mainloop_params.tma_load_KV, _0{}, Layout<_1>{},
                      group_modes<0, 2>(sK), group_modes<0, 2>(gKV));

    static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
    const int start_len = tile_idx * mainloop_params.chunk_size;
    const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
    const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;

    auto kv_block_tables = make_tensor(
        make_gmem_ptr(mainloop_params.kv_block_tables),
        make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq),
                    make_stride(mainloop_params.max_block_num_per_seq, 1)));

    int lane_predicate = cute::elect_one_sync();

    if (lane_predicate) {
      #pragma unroll 2
      for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
        const int block_idx = kv_block_tables(bid, kv_tile_idx);
        pipeline_kv.producer_acquire(smem_pipe_write_kv);
        copy(mainloop_params.tma_load_KV.with(*pipeline_kv.producer_get_barrier(smem_pipe_write_kv), /*mcast_mask=*/0),
             tKgK(_, block_idx), tKsK(_, smem_pipe_write_kv.index()));
        ++smem_pipe_write_kv;
      }
    }
  }
};

} // namespace mla_attn

#endif  // ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_
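load_kv and load_kv_tma walk a chunk's KV tiles from back to front and resolve each logical tile index to a physical cache block through the per-sequence block table before issuing the tile copy. A minimal sketch of that address computation; the pointer names are placeholders, and the real code moves whole tiles with cp.async or TMA rather than returning raw pointers:

// Illustrative paged-KV addressing: logical tile -> physical block -> base address.
const float* kv_tile_ptr(const float* kv_cache, const int* block_tables,
                         int bid, int max_blocks_per_seq,
                         int kv_tile_idx, int block_size, int head_dim) {
  // Per-sequence block table, laid out [bsz, max_blocks_per_seq] row-major.
  const int block_id = block_tables[bid * max_blocks_per_seq + kv_tile_idx];
  // Each physical block holds block_size rows of head_dim elements.
  return kv_cache + static_cast<size_t>(block_id) * block_size * head_dim;
}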
custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh (new file, 500 lines)
@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
#define ATTENTION_HOPPER_MAINLOOP_MMA_CUH_

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "named_barrier.cuh"

// #define DEBUG_MLA

namespace mla_attn {

template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
          typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
                            MainloopPipelineQ pipeline_q,
                            PipelineStateQ& smem_pipe_read_q,
                            MainloopPipeline pipeline_kv,
                            PipelineState& smem_pipe_read_kv,
                            FrgTensorO& tOrO,
                            AttentionUpdater& attention_updater,
                            const int thread_idx,
                            const int bid,
                            const int kv_len,
                            const int qo_len,
                            const int tile_idx,
                            SharedStorage& shared_storage) {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using DTypeMD = typename Ktraits::DTypeO;
  using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
  using IdType = typename Ktraits::IdType;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
  using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
  using SmemLayoutK = typename Ktraits::SmemLayoutK;
  using SmemLayoutV = typename Ktraits::SmemLayoutV;
  using SmemLayoutP = typename Ktraits::SmemLayoutP;
  using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
  using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
  using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
  using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
  static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

  const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);

  static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
  static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});

  Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
  Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
  Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
  Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
  Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
  Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); // (bsz * draft_token_num * num_head)
  Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);

  typename Ktraits::TiledMmaQK tiled_mma_qk;
  auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
  auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
  auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
  Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
  Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup);

  typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
  auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
  Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
  Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
  Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);

  const int start_len = tile_idx * mainloop_params.chunk_size;
  const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
  const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
  int kv_tile_idx = end_tile_idx;

  auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
    auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
    pipeline.consumer_wait(smem_pipe_read, barrier_token);
  };

  int warp_group_idx = cutlass::canonical_warp_group_idx();
  if (warp_group_idx == 1) {
    // consumer 0, compute qk
    Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
    Tensor tSrK = threadMmaQK.partition_fragment_B(sK);

    constexpr int n_masking_steps = !CAUSAL ? 1 : cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) + 1;
    auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
    bool is_first_step = true;
    // wait q
    consumer_wait(pipeline_q, smem_pipe_read_q);
    Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
    #pragma unroll 1
    for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
      // wait kv
      consumer_wait(pipeline_kv, smem_pipe_read_kv);
      // gemm qk
      gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
                                         tSrS);
      // mask
      if (masking_step > 0) {
        Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
        Tensor tScS = threadMmaQK.partition_C(cS);
        #pragma unroll
        for (int i = 0; i < size(tSrS); ++i) {
          int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
          int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
          if constexpr (!CAUSAL) { // Just masking based on col
            if (kv_idx >= kv_len) {
              tSrS(i) = AttentionUpdater::fill_value;
            }
          } else {
            if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
              tSrS(i) = AttentionUpdater::fill_value;
            }
          }
        }
      }

      // update s (exp(s - m))
      Tensor scale_o = is_first_step ? attention_updater.update</*init=*/true>(tSrS) : attention_updater.update</*init=*/false>(tSrS);
      is_first_step = false;

      Tensor convert_tSrS = convert_type<DTypeKV>(tSrS);
      Tensor tPrP = smem_thr_copy_P.retile_S(convert_tSrS);

      // gather qk gemm res
      cute::copy(smem_tiled_copy_P, tPrP, tPsP);
      cute::copy(scale_o, tScalesScale);
      // r2s fence wgmma
      cutlass::arch::fence_view_async_shared();
      // make sure r2s all done
      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));

      attention_updater.rescale_o(tOrO, scale_o);

      // pv gemm
      if (smem_pipe_read_kv.index() == 0) {
        gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
                                            tOrV1(_, _, _, _0{}), tOrO);
      } else {
        gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
                                            tOrV2(_, _, _, _0{}), tOrO);
      }

      pipeline_kv.consumer_release(smem_pipe_read_kv);
      ++smem_pipe_read_kv;
      // sync WG1 WG2
      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
    }
    // release q
    pipeline_q.consumer_release(smem_pipe_read_q);
    ++smem_pipe_read_q;

    // normalize
    Tensor scale_o = attention_updater.finalize(tSrS); // warp reduce row sum
    if (chunk_num_this_seq == 1) {
      // norm
      cute::copy(scale_o, tScalesScale);

      cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
      attention_updater.rescale_o(tOrO, scale_o);
    }

    // WG1 write m,d back to gmem
    if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, eg. t0->row0 row8,t4->row1 row9
      const int warp_idx = thread_idx / 32;
      #pragma unroll
      for (int w_i = 0; w_i < 2; ++w_i) {
        const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
        const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;

        if (token_idx < qo_len) {
          const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
          const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
          const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
          mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
          mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
        }
      }
    }
  } else if (warp_group_idx == 2) {
    // consumer 1, compute pv
    Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
    for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
      // wait kv
      consumer_wait(pipeline_kv, smem_pipe_read_kv);
      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));

      // A: tPsP
      cute::copy(tScalesScale, scale_o);

      // rescale
      attention_updater.rescale_o(tOrO, scale_o);
      if (smem_pipe_read_kv.index() == 0) {
        gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
                                            tOrV1(_, _, _, _0{}), tOrO);
      } else {
        gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
                                            tOrV2(_, _, _, _0{}), tOrO);
      }

      pipeline_kv.consumer_release(smem_pipe_read_kv);
      ++smem_pipe_read_kv;
      // sync WG1 WG2
      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
    }
    if (chunk_num_this_seq == 1) {
      // norm
      cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
      cute::copy(tScalesScale, scale_o);
      attention_updater.rescale_o(tOrO, scale_o);
    }
  }
  return;
}

template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
          typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
                                       MainloopPipelineQ pipeline_q,
                                       PipelineStateQ& smem_pipe_read_q,
                                       MainloopPipeline pipeline_kv,
                                       PipelineState& smem_pipe_read_kv,
                                       FrgTensorO& tOrO,
                                       AttentionUpdater& attention_updater,
                                       const int thread_idx,
                                       const int bid,
                                       const int kv_len,
                                       const int qo_len,
                                       const int tile_idx,
                                       SharedStorage& shared_storage) {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using DTypeMD = typename Ktraits::DTypeO; // !!! bf16
  using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
  using IdType = typename Ktraits::IdType;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
  using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
  using SmemLayoutK = typename Ktraits::SmemLayoutK;
  using SmemLayoutV = typename Ktraits::SmemLayoutV;
  using SmemLayoutP = typename Ktraits::SmemLayoutP;
  using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
  using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
  using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
  using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
  static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

  const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);

  static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
  static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});

  Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
  Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
  Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
  Tensor sVt_s3 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 2 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
  Tensor sVt_s4 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 3 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
|
||||
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
|
||||
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
|
||||
|
||||
typename Ktraits::TiledMmaQK tiled_mma_qk;
|
||||
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
|
||||
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
|
||||
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
|
||||
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup, _);
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
|
||||
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
|
||||
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrV3 = threadMmaPVSS.partition_fragment_B(sVt_s3);
|
||||
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 1) {
|
||||
// consumer 0, compute qk
|
||||
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
|
||||
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
|
||||
// wait q
|
||||
consumer_wait(pipeline_q, smem_pipe_read_q);
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
// wait k
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
// first qk gemm
|
||||
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
// mask
|
||||
{
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor scale_o = attention_updater.update</*init=*/true>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) : 0;
|
||||
--kv_tile_idx;
|
||||
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
PipelineState smem_pipe_read_kv_cur = smem_pipe_read_kv;
|
||||
++smem_pipe_read_kv;
|
||||
// wait next kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
|
||||
// gemm next qk
|
||||
gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
attention_updater.rescale_o(tOrO);
|
||||
// last pv gemm
|
||||
if (smem_pipe_read_kv_cur.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv_cur.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv_cur.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
// wait cur qk gemm
|
||||
warpgroup_wait<1>();
|
||||
// mask p
|
||||
if (masking_step > 0) {
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// update s (exp(s - m))
|
||||
Tensor scale_o = attention_updater.update</*init=*/false>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
|
||||
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
// make sure tSrS r2s done
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
// wait last pv gemm
|
||||
warpgroup_wait<0>();
|
||||
// release last kv
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv_cur);
|
||||
}
|
||||
// release q
|
||||
pipeline_q.consumer_release(smem_pipe_read_q);
|
||||
++smem_pipe_read_q;
|
||||
// compute last pv
|
||||
attention_updater.rescale_o(tOrO);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
scale_o = attention_updater.finalize(tSrS);
|
||||
warpgroup_wait<0>();
|
||||
// release last kv
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
|
||||
attention_updater.rescale_o(tOrO);
|
||||
}
|
||||
// WG1 write m,d back to gmem
|
||||
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, eg. t0->row0 row8,t4->row1 row9
|
||||
const int warp_idx = thread_idx / 32;
|
||||
#pragma unroll
|
||||
for (int w_i = 0; w_i < 2; ++w_i) {
|
||||
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
|
||||
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
|
||||
|
||||
if (token_idx < qo_len) {
|
||||
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
|
||||
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
|
||||
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
|
||||
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
|
||||
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_group_idx == 2) {
|
||||
// consumer 1, compute pv
|
||||
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
|
||||
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
// A: tPsP
|
||||
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
|
||||
// rescale
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
}
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
|
||||
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}

} // namespace mla_attn

#endif  // ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
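For reference, both mma_f16 and mma_f16_two_stages above apply the same per-element mask before the online-softmax update: a score at (qo_idx, kv_idx) survives only while kv_idx stays below min(kv_len, qo_idx + 1 + kv_len - qo_len) in the causal case, or below kv_len otherwise. A minimal host-side sketch of that predicate follows (is_masked_out is an illustrative name, not part of the kernels):

#include <algorithm>

// True when score (qo_idx, kv_idx) must be set to AttentionUpdater::fill_value
// (i.e. -inf) before the exp(s - m) update. qo_idx is the query row after
// dividing out GROUP_SIZE heads; kv_idx is the absolute key/value column.
inline bool is_masked_out(int qo_idx, int kv_idx, int qo_len, int kv_len, bool causal) {
  if (!causal) {
    return kv_idx >= kv_len;  // only the ragged tail of the last KV tile
  }
  const int col_limit_right = qo_idx + 1 + kv_len - qo_len;  // one past the last visible column
  return kv_idx >= std::min(kv_len, col_limit_right);
}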
575
custom_ops/gpu_ops/mla_attn/mla_hopper.cuh
Normal file
@@ -0,0 +1,575 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
||||
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "attention_updater.cuh"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "epilogue.cuh"
|
||||
#include "helper.h"
|
||||
#include "kernel_traits.cuh"
|
||||
#include "mainloop_mma.cuh"
|
||||
#include "mainloop_load.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
struct Params {
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;

  alignas(16) DTypeQ *Q;      // [token_num, head_num, dim_head]
  alignas(16) DTypeKV *KV;    // [max_block_num, block_size, dim_head]
  alignas(16) DTypeO *O;      // [token_num, head_num, dim_head]
  alignas(16) DTypeO *O_tmp;  // [num_chunks, bsz, head_num, dim_head]
  alignas(16) float *m;       // [num_chunks, bsz * max_draft_token_num * head_num]
  alignas(16) float *d;       // [num_chunks, bsz * max_draft_token_num * head_num]

  alignas(16) IdType *block_tables;
  alignas(16) IdType *seq_lens_this_time;
  alignas(16) IdType *seq_lens_encoder;
  alignas(16) IdType *seq_lens_decoder;
  alignas(16) IdType *cumsum_q_seqlens;
  alignas(16) IdType *padding_offsets;

  alignas(16) IdType *batch_ids;
  alignas(16) IdType *tile_ids_per_batch;
  alignas(16) IdType *num_blocks_x;

  uint32_t q_stride_bsz;
  uint32_t q_stride_head_num;

  uint32_t kv_stride_block_num;
  uint32_t kv_stride_block_size;

  uint32_t o_stride_bsz;
  uint32_t o_stride_head_num;

  int bsz;
  int token_num;
  int max_seq_len;
  int max_block_num;
  int max_block_num_per_seq;
  int q_num_head;
  int qk_head_dim;
  int vo_head_dim;
  int block_size;
  int max_draft_token_num;
  int chunk_size;
  int chunk_num;
  int num_blocks_x_int;

  float sm_scale;
};

#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...)     \
  if (group_size == 8) {                                     \
    constexpr size_t GROUP_SIZE = 8;                         \
    __VA_ARGS__                                              \
  } else if (group_size == 16) {                             \
    constexpr size_t GROUP_SIZE = 16;                        \
    __VA_ARGS__                                              \
  } else if (group_size == 64) {                             \
    constexpr size_t GROUP_SIZE = 64;                        \
    __VA_ARGS__                                              \
  } else {                                                   \
    PD_THROW("not support the group_size: ", group_size);    \
    return cudaErrorNotSupported;                            \
  }
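// Note (illustrative, not part of the original file): DISPATCH_GROUP_SIZE turns
// the runtime head-group count into the compile-time GROUP_SIZE that
// AttentionKernelTraits needs. A hedged sketch of the same runtime-to-constexpr
// dispatch, written as a plain template instead of a macro (dispatch_group_size
// and the callback are hypothetical names; requires <type_traits> and
// <cuda_runtime.h>):
//
//   template <typename Func>
//   cudaError_t dispatch_group_size(int group_size, Func&& func) {
//     if (group_size == 8)  return func(std::integral_constant<size_t, 8>{});
//     if (group_size == 16) return func(std::integral_constant<size_t, 16>{});
//     if (group_size == 64) return func(std::integral_constant<size_t, 64>{});
//     return cudaErrorNotSupported;  // only these group sizes are compiled
//   }
//
// Each branch instantiates the callee with a distinct integral constant, which
// is why only the listed group sizes are supported at runtime.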
|
||||
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT
|
||||
typename CollectiveEpilogue::Params const epilogue_params) {
|
||||
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeO = typename Ktraits::DTypeO;
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
using TileShape_PDV = typename Ktraits::TileShape_PDV;
|
||||
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
const int num_blocks_x = mainloop_params.num_blocks_x[0];
|
||||
|
||||
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
|
||||
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
|
||||
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||
|
||||
int const lane_predicate = cute::elect_one_sync();
|
||||
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
|
||||
if (warp_idx == 0 && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
|
||||
}
|
||||
|
||||
// Obtain warp index
|
||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = NUM_MMA_THREADS;
|
||||
} else {
|
||||
pipeline_params.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
|
||||
}
|
||||
|
||||
PipelineParamsQ pipeline_params_q;
|
||||
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
|
||||
: MainloopPipelineQ::ThreadCategory::Consumer;
|
||||
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk
|
||||
|
||||
|
||||
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
|
||||
MainloopPipeline pipeline_kv = [&] {
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
|
||||
/*cluster_shape=*/Shape<_1, _1, _1>{});
|
||||
} else {
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
|
||||
}
|
||||
}();
|
||||
__syncthreads();
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue;
|
||||
|
||||
if (warp_group_idx == 0) {
|
||||
// producer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<72>();
|
||||
}
|
||||
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
|
||||
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// consumer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_alloc<216>();
|
||||
}
|
||||
PipelineStateQ smem_pipe_read_q;
|
||||
PipelineState smem_pipe_read_kv;
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
|
||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaStream_t stream) {
|
||||
using DTypeQ = typename KernelTraits::DTypeQ;
|
||||
using DTypeKV = typename KernelTraits::DTypeKV;
|
||||
using DTypeO = typename KernelTraits::DTypeO;
|
||||
using IdType = typename KernelTraits::IdType;
|
||||
using NV_TYPE = typename KernelTraits::NV_TYPE;
|
||||
|
||||
using CollectiveMainloop =
|
||||
CollectiveMainloop<KernelTraits, CAUSAL>;
|
||||
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
|
||||
|
||||
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
|
||||
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
|
||||
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
|
||||
params.Q,
|
||||
params.KV,
|
||||
params.m,
|
||||
params.d,
|
||||
params.block_tables,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_encoder,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_ids,
|
||||
params.tile_ids_per_batch,
|
||||
params.num_blocks_x,
|
||||
params.sm_scale,
|
||||
params.bsz,
|
||||
params.max_block_num,
|
||||
params.max_block_num_per_seq,
|
||||
params.q_stride_bsz,
|
||||
params.q_stride_head_num,
|
||||
params.kv_stride_block_num,
|
||||
params.kv_stride_block_size,
|
||||
params.o_stride_bsz,
|
||||
params.o_stride_head_num,
|
||||
params.chunk_size,
|
||||
params.chunk_num,
|
||||
params.max_draft_token_num
|
||||
});
|
||||
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
|
||||
params.O,
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
|
||||
params.O_tmp,
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
|
||||
});
|
||||
|
||||
// Get the ptr to kernel function.
|
||||
auto kernel =
|
||||
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
|
||||
int smem_size = sizeof(typename KernelTraits::SharedStorage);
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiprocessor_count;
|
||||
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
|
||||
int act_blocks_per_sm;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
||||
|
||||
int gridx;
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
gridx = multiprocessor_count;
|
||||
} else {
|
||||
gridx = params.num_blocks_x_int;
|
||||
}
|
||||
dim3 grid_dims = {gridx, 1, 1};
|
||||
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
||||
dim3 block_dims(ctaSize, 1, 1);
|
||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
|
||||
mainloop_params, epilogue_params
|
||||
);
|
||||
if (params.chunk_num > 1) {
|
||||
constexpr int vec_size = 16 / sizeof(DTypeO);
|
||||
constexpr int merge_block_size = 256;
|
||||
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
||||
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.seq_lens_encoder,
|
||||
params.padding_offsets,
|
||||
reinterpret_cast<NV_TYPE*>(params.O),
|
||||
params.max_seq_len,
|
||||
params.chunk_num,
|
||||
params.q_num_head,
|
||||
params.chunk_size,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num
|
||||
);
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
|
||||
constexpr bool CAUSAL = true;
|
||||
if constexpr (HEAD_DIM_QK == 576) {
|
||||
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
|
||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
|
||||
HEAD_DIM_QK,
|
||||
HEAD_DIM_VO,
|
||||
GROUP_SIZE,
|
||||
/*BLOCK_SHAPE_Q_=*/64,
|
||||
/*BLOCK_SHAPE_KV_=*/64,
|
||||
/*NUM_STAGES_=*/2,
|
||||
typename Params::DTypeQ,
|
||||
typename Params::DTypeKV,
|
||||
typename Params::DTypeO,
|
||||
typename Params::IdType,
|
||||
NV_TYPE>,
|
||||
CAUSAL,
|
||||
Params,
|
||||
USE_REG_EALLOC,
|
||||
USE_FIXED_BLOCK>(params, stream);)
|
||||
} else {
|
||||
return cudaErrorNotSupported;
|
||||
}
|
||||
return cudaSuccess;
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
47
custom_ops/gpu_ops/mla_attn/named_barrier.cuh
Normal file
@@ -0,0 +1,47 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */

#ifndef ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
#define ATTENTION_HOPPER_NAMED_BARRIERS_CUH_

#include <cuda_runtime.h>

#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"

namespace mla_attn {

enum class NamedBarriers {
  kQueryEmpty = 0,
  kValueEmpty = 1,
  kWarpSchedulerWG1 = 2,
  kWarpSchedulerWG2 = 3,
  kWarpSchedulerWG3 = 4,
  kPrefetchIndices = 5,
  kOdone = 6,
  kWG1WG2Sync = 7,
  kWG0WG1WG2Sync = 8,
  kWG1WG2LastSync = 9,
};

} // namespace mla_attn

#endif  // ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
351
custom_ops/gpu_ops/mla_attn/utils.cuh
Normal file
@@ -0,0 +1,351 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_UTILS_CUH_
|
||||
#define ATTENTION_HOPPER_UTILS_CUH_
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include <assert.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename TensorT>
|
||||
CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) {
|
||||
Tensor tensor_flatten = cute::flatten(tensor);
|
||||
return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride,
|
||||
int64_t h_stride) {
|
||||
return make_layout(make_shape(nnz, head_dim, num_heads),
|
||||
make_stride(n_stride, cute::_1{}, h_stride));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) {
|
||||
return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads)));
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
|
||||
int head_idx, int offset, int seq_len) {
|
||||
auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)),
|
||||
make_coord(offset, _0{}));
|
||||
auto g_sequence =
|
||||
make_tensor(g_offset.data(),
|
||||
make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride()));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
|
||||
int head_idx, int offset, int seq_len) {
|
||||
auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset));
|
||||
|
||||
auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len),
|
||||
cute::make_shape(shape<0>(m_tensor))));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V,
|
||||
// MMA_N))
|
||||
template <typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = acc_layout;
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
|
||||
make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16,
|
||||
// MMA_N))
|
||||
template <typename MMA_traits, typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
|
||||
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
|
||||
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout),
|
||||
make_layout(get<2, 1>(l), get<2>(acc_layout)));
|
||||
};
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template <bool init = false, int wg_wait = 0, typename TensorA, typename TensorB, typename TensorC,
|
||||
typename TiledMma>
|
||||
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB,
|
||||
TensorC& tCrC) {
|
||||
constexpr bool Is_RS =
|
||||
!cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
||||
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
|
||||
if constexpr (Is_RS) {
|
||||
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
|
||||
}
|
||||
warpgroup_fence_operand(tCrC);
|
||||
warpgroup_arrive();
|
||||
if constexpr (init) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
} else {
|
||||
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
if constexpr (wg_wait >= 0) {
|
||||
warpgroup_wait<wg_wait>();
|
||||
}
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (Is_RS) {
|
||||
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
|
||||
}
|
||||
}
|
||||
|
||||
#define HOSTDEVICE __host__ __device__
|
||||
|
||||
template <typename T, int Size>
|
||||
struct alignas(sizeof(T) * Size) AlignedVector {
|
||||
T val[Size];
|
||||
|
||||
HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
|
||||
HOSTDEVICE inline T& operator[](int i) { return val[i]; }
|
||||
};
|
||||
|
||||
template <typename T, int Size>
|
||||
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
|
||||
const AlignedVector<T, Size>* addr_vec =
|
||||
reinterpret_cast<const AlignedVector<T, Size>*>(addr);
|
||||
*vec = *addr_vec;
|
||||
}
|
||||
|
||||
template <typename T, int Size>
|
||||
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
|
||||
AlignedVector<T, Size>* addr_vec =
|
||||
reinterpret_cast<AlignedVector<T, Size>*>(addr);
|
||||
*addr_vec = vec;
|
||||
}
|
||||
|
||||
template <size_t vec_size, typename T>
|
||||
struct prefill_softmax_state_t {
|
||||
AlignedVector<T, vec_size> o;
|
||||
float m;
|
||||
float d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
|
||||
const float other_m,
|
||||
const float other_d) {
|
||||
float m_prev = m, d_prev = d;
|
||||
m = max(m_prev, other_m);
|
||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
|
||||
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
|
||||
d = d_prev * scale1 + other_d * scale2;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize() {
|
||||
const T d_t = static_cast<T>(d);
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d_t;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
|
||||
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const int * __restrict__ seq_lens_this_time,
|
||||
const int * __restrict__ seq_lens_decoder,
|
||||
const int * __restrict__ seq_lens_encoder,
|
||||
const int * __restrict__ padding_offsets,
|
||||
T * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const int max_seq_len,
|
||||
const int num_chunks,
|
||||
const int num_heads,
|
||||
const int chunk_size,
|
||||
const int head_dim,
|
||||
const int token_num,
|
||||
const int bsz,
|
||||
const int max_draft_token_num) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
const int seq_len_q = seq_lens_this_time[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
const uint32_t local_seq_id = ori_token_id % max_seq_len;
|
||||
int seq_len_kv = seq_lens_decoder[bid];
|
||||
if (seq_len_kv == 0) continue;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
// not need merge
|
||||
continue;
|
||||
}
|
||||
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
float m;
|
||||
float d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
m = -3.0e+30f;
|
||||
}
|
||||
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset;
|
||||
offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid;
|
||||
float m_prev = m;
|
||||
float d_prev = d;
|
||||
const float m_now = multi_m[offset];
|
||||
const float d_now = multi_d[offset];
|
||||
m = max(m_prev, m_now);
|
||||
offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
||||
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
if (ty == 0) {
|
||||
// merge bdy
|
||||
prefill_softmax_state_t<vec_size, T> st;
|
||||
st.init();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < bdy; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
Store<T, vec_size>(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_UTILS_CUH_
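The chunk-merge path above (prefill_softmax_state_t::merge and merge_multi_chunks_kernel) relies on the usual online-softmax identity: each chunk keeps an unnormalized weighted sum o together with its row max m and row sum d, two partials combine with scales exp(m_i - m_new), and division by the merged d happens only once at the end. A minimal scalar sketch of that rule follows (PartialSoftmax is an illustrative name, not a type from this PR):

#include <algorithm>
#include <cmath>

struct PartialSoftmax {
  float o;  // unnormalized weighted sum for one output element
  float m;  // row max seen so far
  float d;  // row sum of exp(s - m)

  void merge(const PartialSoftmax& other) {
    const float m_new = std::max(m, other.m);
    const float s1 = std::exp(m - m_new);
    const float s2 = std::exp(other.m - m_new);
    o = o * s1 + other.o * s2;  // rescale both partial numerators to the new max
    d = d * s1 + other.d * s2;  // and their denominators
    m = m_new;
  }

  float normalize() const { return o / d; }  // final softmax-weighted value
};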
@@ -1255,8 +1255,6 @@ __global__ void Marlin(
    if constexpr (has_zp && !is_zp_float) {
      if (is_new_zp) {
        if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
        FragB frag_zp_0;
        FragB frag_zp_1;
        int zp_quant_0, zp_quant_1;

        if constexpr (w_type.size_bits() == 4) {
469
custom_ops/gpu_ops/multi_head_latent_attention.cu
Normal file
@@ -0,0 +1,469 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "append_attn/multi_head_latent_attention_kernel.h"
#include "mla_attn/batch_mla_with_paged_kv_cache.h"

template <paddle::DataType D>
std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& query, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache,
    const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets, const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& encoder_batch_ids, const paddle::Tensor& encoder_tile_ids_per_batch,
    const paddle::Tensor& encoder_num_blocks,
    const paddle::Tensor& kv_batch_ids, const paddle::Tensor& kv_tile_ids_per_batch,
    const paddle::Tensor& kv_num_blocks,
    const paddle::Tensor& decoder_batch_ids, const paddle::Tensor& decoder_tile_ids_per_batch,
    const paddle::Tensor& decoder_num_blocks, const paddle::Tensor& decoder_num_blocks_cpu,
    const paddle::Tensor& max_enc_len_this_time, const paddle::Tensor& max_dec_len_this_time,
    const paddle::Tensor& max_len_kv,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& query_bias,
    const paddle::optional<paddle::Tensor>& query_out_scales,
    const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
    const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
    const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
    const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const paddle::optional<paddle::Tensor>& out_linear_shifts,
    const paddle::optional<paddle::Tensor>& out_linear_smooths,
    const std::string& cache_quant_type_str,
    const int max_input_length, const float softmax_scale,
    const float quant_max_bound, const float quant_min_bound,
    const float out_linear_in_scale, const int speculate_max_draft_token_num,
    const bool causal, const bool speculate_decoder) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::data_t data_t;

  int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
  int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
  int max_len_kv_data = max_len_kv.data<int>()[0];

  const bool mla_use_tensorcore = get_mla_use_tensorcore();
  auto sm_version = GetSMVersion();
  if ((speculate_decoder || mla_use_tensorcore) && sm_version < 90) {
    PD_THROW("Please use speculate_decoder=0 and FLAGS_mla_use_tensorcore=0 when sm < 90.");
  }

  auto main_stream = query.stream();

  paddle::Tensor fmha_out = paddle::full(
      {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
      0, D, query.place());

  if (max_dec_len_this_time_data > 0) {
    if (mla_use_tensorcore) {
      BatchMLAWithPagedKVCacheKernel<data_t>(
          meta_data, query, key_cache, attn_mask,
          cache_k_dequant_scales, cache_v_dequant_scales, cache_k_zp, cache_v_zp,
          out_linear_shifts, out_linear_smooths,
          seq_lens_this_time, seq_lens_decoder, seq_lens_encoder,
          cu_seqlens_q, padding_offsets, cum_offsets, block_tables,
          decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks,
          cache_quant_type_str, decoder_num_blocks_data,
          max_input_length, max_len_kv_data,
          softmax_scale, quant_max_bound, quant_min_bound, out_linear_in_scale,
          speculate_max_draft_token_num, causal, main_stream, &fmha_out);
    } else {
      DecodeMLAAttentionKernel<data_t>(
          meta_data,
          query,  // [token_num, num_heads, head_dim]
          key_cache, value_cache, attn_mask,
          out_linear_shifts, out_linear_smooths,
          seq_lens_this_time,  // q_seq_len is 1
          seq_lens_decoder, padding_offsets, cum_offsets, block_tables,
          max_input_length, max_len_kv_data, softmax_scale, out_linear_in_scale,
          causal, main_stream, &fmha_out);
    }
  }
  return {fmha_out};
}

std::vector<paddle::Tensor> MultiHeadLatentAttention(
    const paddle::Tensor& query, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache,
    const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets, const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& encoder_batch_ids, const paddle::Tensor& encoder_tile_ids_per_batch,
    const paddle::Tensor& encoder_num_blocks,
    const paddle::Tensor& kv_batch_ids, const paddle::Tensor& kv_tile_ids_per_batch,
    const paddle::Tensor& kv_num_blocks,
    const paddle::Tensor& decoder_batch_ids, const paddle::Tensor& decoder_tile_ids_per_batch,
    const paddle::Tensor& decoder_num_blocks, const paddle::Tensor& decoder_num_blocks_cpu,
    const paddle::Tensor& max_enc_len_this_time, const paddle::Tensor& max_dec_len_this_time,
    const paddle::Tensor& max_len_kv,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& query_bias,
    const paddle::optional<paddle::Tensor>& query_out_scales,
    const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
    const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
    const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
    const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const paddle::optional<paddle::Tensor>& out_linear_shifts,
    const paddle::optional<paddle::Tensor>& out_linear_smooths,
    const std::string& compute_dtype, const std::string& cache_quant_type_str,
    const int nope_size, const int max_input_length, const float softmax_scale,
    const float quant_max_bound, const float quant_min_bound,
    const float out_linear_in_scale, const int speculate_max_draft_token_num,
    const bool causal, const bool speculate_decoder) {
  AppendAttnMetaData meta_data;

  const auto& query_dims = query.dims();
  const auto& key_cache_dims = key_cache.dims();
  const int q_hidden_size = query_dims[query_dims.size() - 1];
  meta_data.token_nums = query_dims[0];
  meta_data.kv_num_heads = key_cache_dims[1];
  meta_data.head_dims = key_cache_dims[3];
  meta_data.head_dims_v = nope_size;
  meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;

  meta_data.max_blocks_per_seq = block_tables.dims()[1];
  meta_data.block_size = key_cache.dims()[2];
  meta_data.batch_size = cum_offsets.dims()[0];

  switch (query.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return MultiHeadLatentAttentionKernel<paddle::DataType::BFLOAT16>(
          meta_data, query, key_cache, value_cache,
          seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, cu_seqlens_q,
          padding_offsets, cum_offsets, block_tables,
          encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks,
          kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
          decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks,
          decoder_num_blocks_cpu, max_enc_len_this_time, max_dec_len_this_time, max_len_kv,
          attn_mask, query_bias, query_out_scales,
          cache_k_quant_scales, cache_v_quant_scales,
          cache_k_dequant_scales, cache_v_dequant_scales,
          cache_k_zp, cache_v_zp, out_linear_shifts, out_linear_smooths,
          cache_quant_type_str, max_input_length, softmax_scale,
          quant_max_bound, quant_min_bound, out_linear_in_scale,
          speculate_max_draft_token_num, causal, speculate_decoder);
    }
    case paddle::DataType::FLOAT16: {
      return MultiHeadLatentAttentionKernel<paddle::DataType::FLOAT16>(
          meta_data, query, key_cache, value_cache,
          seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, cu_seqlens_q,
          padding_offsets, cum_offsets, block_tables,
          encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks,
          kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
          decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks,
          decoder_num_blocks_cpu, max_enc_len_this_time, max_dec_len_this_time, max_len_kv,
          attn_mask, query_bias, query_out_scales,
          cache_k_quant_scales, cache_v_quant_scales,
          cache_k_dequant_scales, cache_v_dequant_scales,
          cache_k_zp, cache_v_zp, out_linear_shifts, out_linear_smooths,
          cache_quant_type_str, max_input_length, softmax_scale,
          quant_max_bound, quant_min_bound, out_linear_in_scale,
          speculate_max_draft_token_num, causal, speculate_decoder);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16 and bfloat16 are supported. ");
      break;
    }
  }
}

std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
    const std::vector<int64_t>& query_shape,
    const std::vector<int64_t>& key_cache_shape,
    const std::vector<int64_t>& value_cache_shape,
    const std::vector<int64_t>& seq_lens_encoder_shape,
    const std::vector<int64_t>& seq_lens_decoder_shape,
    const std::vector<int64_t>& seq_lens_this_time_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& padding_offsets_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& block_tables_shape,
    const std::vector<int64_t>& encoder_batch_ids_shape,
    const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
    const std::vector<int64_t>& encoder_num_blocks_shape,
    const std::vector<int64_t>& kv_batch_ids_shape,
    const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
    const std::vector<int64_t>& kv_num_blocks_shape,
    const std::vector<int64_t>& decoder_batch_ids_shape,
    const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
    const std::vector<int64_t>& decoder_num_blocks_shape,
    const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
    const std::vector<int64_t>& max_enc_len_this_time_shape,
    const std::vector<int64_t>& max_dec_len_this_time_shape,
    const std::vector<int64_t>& max_len_kv_shape,
    const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
    const paddle::optional<std::vector<int64_t>>& query_bias_shape,
    const paddle::optional<std::vector<int64_t>>& query_out_scales_shape,
    const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
    const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
    const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
    const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
    const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
    const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
    const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
    const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
    const std::string& compute_dtype, const std::string& cache_quant_type_str,
    const int nope_size, const int max_input_length, const float softmax_scale,
    const float quant_max_bound, const float quant_min_bound,
    const float out_linear_in_scale, const int speculate_max_draft_token_num,
    const bool causal, const bool speculate_decoder) {
  const int token_num = query_shape[0];
  const int kv_num_heads = key_cache_shape[1];
  const int head_dim_qk = key_cache_shape[3];
  const int head_dim_v = nope_size;
  const int q_hidden_size = query_shape[query_shape.size() - 1];
  const int num_heads = q_hidden_size / head_dim_qk;
  return {{token_num, num_heads * head_dim_v}};
}

std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
    const paddle::DataType& query_dtype,
    const paddle::DataType& key_cache_dtype,
    const paddle::DataType& value_cache_dtype,
    const paddle::DataType& seq_lens_encoder_dtype,
    const paddle::DataType& seq_lens_decoder_dtype,
    const paddle::DataType& seq_lens_this_time_dtype,
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& padding_offsets_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& block_tables_dtype,
    const paddle::DataType& encoder_batch_ids_dtype,
    const paddle::DataType& encoder_tile_ids_per_batch_dtype,
    const paddle::DataType& encoder_num_blocks_dtype,
    const paddle::DataType& kv_batch_ids_dtype,
    const paddle::DataType& kv_tile_ids_per_batch_dtype,
    const paddle::DataType& kv_num_blocks_dtype,
    const paddle::DataType& decoder_batch_ids_dtype,
    const paddle::DataType& decoder_tile_ids_per_batch_dtype,
    const paddle::DataType& decoder_num_blocks_dtype,
    const paddle::DataType& decoder_num_blocks_cpu_dtype,
    const paddle::DataType& max_enc_len_this_time_dtype,
    const paddle::DataType& max_dec_len_this_time_dtype,
    const paddle::DataType& max_len_kv_dtype,
    const paddle::optional<paddle::DataType>& attn_mask_dtype,
    const paddle::optional<paddle::DataType>& query_bias_dtype,
    const paddle::optional<paddle::DataType>& query_out_scales_dtype,
    const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
    const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
    const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
    const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
    const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
    const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
    const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
    const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
    const std::string& compute_dtype, const std::string& cache_quant_type_str,
    const int nope_size, const int max_input_length, const float softmax_scale,
    const float quant_max_bound, const float quant_min_bound,
    const float out_linear_in_scale, const int speculate_max_draft_token_num,
    const bool causal, const bool speculate_decoder) {
  if (compute_dtype == "bf16") {
    return {paddle::DataType::BFLOAT16};
  } else if (compute_dtype == "fp16") {
    return {paddle::DataType::FLOAT16};
  } else {
    PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
  }
}

PD_BUILD_OP(multi_head_latent_attention)
    .Inputs({"query", "key_cache", "value_cache",
             "seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time",
             "cu_seqlens_q", "padding_offsets", "cum_offsets", "block_tables",
             "encoder_batch_ids", "encoder_tile_ids_per_batch", "encoder_num_blocks",
             "kv_batch_ids", "kv_tile_ids_per_batch", "kv_num_blocks",
             "decoder_batch_ids", "decoder_tile_ids_per_batch", "decoder_num_blocks",
             "decoder_num_blocks_cpu", "max_enc_len_this_time", "max_dec_len_this_time",
             "max_len_kv",
             paddle::Optional("attn_mask"),
             paddle::Optional("query_bias"),
             paddle::Optional("query_out_scales"),
             paddle::Optional("cache_k_quant_scales"),
             paddle::Optional("cache_v_quant_scales"),
             paddle::Optional("cache_k_dequant_scales"),
             paddle::Optional("cache_v_dequant_scales"),
             paddle::Optional("cache_k_zp"),
             paddle::Optional("cache_v_zp"),
             paddle::Optional("out_linear_shifts"),
             paddle::Optional("out_linear_smooths")})
    .Outputs({"fmha_out"})
    .Attrs({"compute_type: std::string", "cache_quant_type: std::string",
            "nope_size: int", "max_input_length: int", "softmax_scale: float",
            "quant_max_bound: float", "quant_min_bound: float",
            "out_linear_in_scale: float", "speculate_max_draft_token_num: int",
            "causal: bool", "speculate_decoder: bool"})
    .SetKernelFn(PD_KERNEL(MultiHeadLatentAttention))
    .SetInferShapeFn(PD_INFER_SHAPE(MultiHeadLatentAttentionInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MultiHeadLatentAttentionInferDtype));
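For readers tracing the shape bookkeeping above, this small host-side sketch reproduces the meta_data and output-shape arithmetic; all concrete numbers (token count, head dim, nope_size, number of query heads) are hypothetical examples, not values taken from this diff:

#include <cstdio>

// Minimal sketch, assuming a key_cache layout of
// [max_blocks, kv_num_heads, block_size, head_dim_qk] and a query of
// [token_num, q_num_heads * head_dim_qk], mirroring the code above.
int main() {
  const int token_num = 256;                    // hypothetical
  const int head_dim_qk = 576;                  // hypothetical (nope + rope part)
  const int nope_size = 512;                    // hypothetical head_dims_v
  const int q_hidden_size = 16 * head_dim_qk;   // hypothetical: 16 query heads

  const int q_num_heads = q_hidden_size / head_dim_qk;
  const int head_dim_v = nope_size;
  // Same formula as MultiHeadLatentAttentionInferShape:
  std::printf("fmha_out shape = {%d, %d}\n", token_num, q_num_heads * head_dim_v);
  return 0;
}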
73  custom_ops/gpu_ops/noaux_tc.cu  Normal file
@@ -0,0 +1,73 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <optional>

#include "helper.h"
#include "noauxtc_kernel.h"

std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                    paddle::Tensor& scores_with_bias,
                                    int n_group,
                                    int topk_group,
                                    int topk,
                                    float routed_scaling_factor) {
  auto input_shape = scores_with_bias.shape();
  int64_t num_tokens = input_shape[0];
  int64_t num_experts = input_shape[1];
  auto input_type = scores_with_bias.dtype();
  auto place = scores_with_bias.place();
  auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
  auto stream = scores_with_bias.stream();

  invokeNoAuxTc<float>(reinterpret_cast<float*>(scores.data<float>()),
                       reinterpret_cast<float*>(group_scores.data<float>()),
                       reinterpret_cast<float*>(scores_with_bias.data<float>()),
                       num_tokens, num_experts, n_group, topk_group, topk,
                       routed_scaling_factor, stream);

  return {scores};
}

std::vector<paddle::DataType> NoauxTcInferDtype(
    const paddle::DataType& scores_dtype,
    const paddle::DataType& scores_with_bias_dtype) {
  return {scores_dtype};
}

std::vector<std::vector<int64_t>> NoauxTcInferShape(
    const std::vector<int64_t>& scores_shape,
    const std::vector<int64_t>& gating_output_shape) {
  return {scores_shape};
}

PD_BUILD_OP(noaux_tc)
    .Inputs({"scores", "scores_with_bias"})
    .Outputs({"output_tensor"})
    .Attrs({"n_group: int",
            "topk_group: int",
            "topk:int",
            "routed_scaling_factor: float"})
    .SetKernelFn(PD_KERNEL(NoauxTc))
    .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcInferDtype));
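As a plain-C++ reference for what noaux_tc computes per token (group scores from the top-2 biased scores, selection of topk_group groups, top-k experts within those groups, then renormalization of the unbiased scores and scaling), here is a hedged host-side sketch that follows the kernels in noauxtc_kernel.h below; the function name and the sequential structure are illustrative, not a drop-in replacement:

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

// Reference for one token: scores_with_bias is [num_experts], laid out as
// n_group contiguous groups. Mirrors topk_with_k2_kernel plus
// group_idx_and_topk_idx_kernel, but sequentially on the CPU.
std::vector<float> noaux_tc_reference(const std::vector<float>& scores,
                                      const std::vector<float>& scores_with_bias,
                                      int n_group, int topk_group, int topk,
                                      float routed_scaling_factor) {
  const int experts_per_group = static_cast<int>(scores.size()) / n_group;
  // 1) group score = sum of the two largest biased scores in the group.
  std::vector<std::pair<float, int>> group_scores(n_group);
  for (int g = 0; g < n_group; ++g) {
    std::vector<float> grp(scores_with_bias.begin() + g * experts_per_group,
                           scores_with_bias.begin() + (g + 1) * experts_per_group);
    std::partial_sort(grp.begin(), grp.begin() + 2, grp.end(), std::greater<float>());
    group_scores[g] = {grp[0] + grp[1], g};
  }
  // 2) keep the topk_group best groups.
  std::partial_sort(group_scores.begin(), group_scores.begin() + topk_group,
                    group_scores.end(), std::greater<std::pair<float, int>>());
  // 3) top-k experts (by biased score) among the kept groups.
  std::vector<int> candidates;
  for (int i = 0; i < topk_group; ++i) {
    int g = group_scores[i].second;
    for (int e = 0; e < experts_per_group; ++e) candidates.push_back(g * experts_per_group + e);
  }
  std::partial_sort(candidates.begin(), candidates.begin() + topk, candidates.end(),
                    [&](int a, int b) { return scores_with_bias[a] > scores_with_bias[b]; });
  // 4) renormalize the unbiased scores of the chosen experts and scale.
  std::vector<float> out(scores.size(), 0.0f);
  float sum = 1e-20f;
  for (int i = 0; i < topk; ++i) sum += scores[candidates[i]];
  for (int i = 0; i < topk; ++i)
    out[candidates[i]] = scores[candidates[i]] / sum * routed_scaling_factor;
  return out;
}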
551  custom_ops/gpu_ops/noauxtc_kernel.h  Normal file
@@ -0,0 +1,551 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This code is partially inspired by and references the implementation found
// in NVIDIA TRTLLM.
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

namespace warp_topk {

template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
  if (len == 0) {
    return 0;
  }
  return ((len - 1) / size + 1) * size;
}

template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
  return (v && !(v & (v - 1)));
}

template <bool greater, typename T>
__device__ bool is_better_than(T val, T baseline) {
  return (val > baseline && greater) || (val < baseline && !greater);
}

template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
  int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
  int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
  return max(cache_topk,
             round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
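To make the shared-memory budget concrete, the host-side snippet below evaluates the same formula for the launch configuration used later in this header (NUM_WARPS_PER_BLOCK = 512 / 32 = 16 warps, T = float, idxT = int32_t); topk = 8 is a hypothetical example value:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Host-only re-statement of calc_smem_size_for_block_wide above.
template <int size>
constexpr int64_t round_up_to_multiple_of(int64_t len) {
  return len == 0 ? 0 : ((len - 1) / size + 1) * size;
}

int main() {
  const int num_of_warp = 16;   // BLOCK_SIZE / WARP_SIZE
  const int64_t k = 8;          // hypothetical topk
  const int64_t cache_topk = (sizeof(float) + sizeof(int32_t)) * num_of_warp * k;  // 1024 B
  const int64_t n = std::max<int64_t>(num_of_warp / 2 * k, num_of_warp * 32);      // 512
  const int64_t vals = round_up_to_multiple_of<256>(n * 4);                        // 2048 B of values
  const int64_t idxs = n * 4;                                                      // 2048 B of indices
  const int64_t smem = std::max(cache_topk, vals + idxs);                          // 4096 B
  std::printf("dynamic smem = %lld bytes\n", static_cast<long long>(smem));
  return 0;
}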

template <int size, bool ascending, typename T, typename idxT>
struct BitonicMerge {
  // input should be a bitonic sequence, and sort it to be a monotonic sequence
  __device__ static void merge(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
    static_assert(isPowerOf2(size));
    static_assert(size >= 2 * WARP_SIZE);
    constexpr int arr_len = size / WARP_SIZE;

    constexpr int stride = arr_len / 2;
    for (int i = 0; i < stride; ++i) {
      int const other_i = i + stride;
      T& val = val_arr[i];
      T& other_val = val_arr[other_i];
      if ((val > other_val && ascending) || (val < other_val && !ascending)) {
        T tmp = val;
        val = other_val;
        other_val = tmp;

        idxT tmp2 = idx_arr[i];
        idx_arr[i] = idx_arr[other_i];
        idx_arr[other_i] = tmp2;
      }
    }

    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
                                                      idx_arr + arr_len / 2);
  }
};

template <int size, bool ascending, typename T, typename idxT>
struct BitonicSort {
  __device__ static void sort(T* __restrict__ val_arr,
                              idxT* __restrict__ idx_arr) {
    static_assert(isPowerOf2(size));
    static_assert(size >= 2 * WARP_SIZE);
    constexpr int arr_len = size / WARP_SIZE;

    BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
    BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
                                                idx_arr + arr_len / 2);
    BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
  }
};

template <bool ascending, typename T, typename idxT>
struct BitonicSort<32, ascending, T, idxT> {
  __device__ static void sort(T* __restrict__ val_arr,
                              idxT* __restrict__ idx_arr) {
    int const lane = threadIdx.x % WARP_SIZE;

    // ascending doesn't matter before merging since all we need is a bitonic
    // sequence
    for (int stage = 0; stage < 4; ++stage) {
      for (int stride = (1 << stage); stride > 0; stride /= 2) {
        bool reverse = (lane >> stage) & 2;
        bool is_second = lane & stride;

        T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
        idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
        if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
          *val_arr = other;
          *idx_arr = other_idx;
        }
      }
    }

    BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
  }
};

template <bool ascending, typename T, typename idxT>
struct BitonicMerge<32, ascending, T, idxT> {
  __device__ static void merge(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
    int const lane = threadIdx.x % WARP_SIZE;
    for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
      bool is_second = lane & stride;
      T& val = *val_arr;
      T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
      idxT& idx = *idx_arr;
      idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
      if (val != other && ((val > other) == (ascending != is_second))) {
        val = other;
        idx = other_idx;
      }
    }
  }
};

template <int capacity, bool greater, typename T, typename idxT>
class WarpSort {
 public:
  __device__ WarpSort(idxT k, T dummy)
      : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
    static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));

    for (int i = 0; i < max_arr_len_; ++i) {
      val_arr_[i] = dummy_;
    }
  }

  // load and merge k sorted values
  __device__ void load_sorted(T const* __restrict__ in,
                              idxT const* __restrict__ in_idx,
                              idxT start) {
    idxT idx = start + WARP_SIZE - 1 - lane_;
    for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
      if (idx < start + k_) {
        T t = in[idx];
        if (is_better_than<greater>(t, val_arr_[i])) {
          val_arr_[i] = t;
          idx_arr_[i] = in_idx[idx];
        }
      }
    }

    BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
  }

  __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
    for (int i = 0; i < max_arr_len_; ++i) {
      idxT out_i = i * WARP_SIZE + lane_;
      if (out_i < k_) {
        out[out_i] = val_arr_[i];
        out_idx[out_i] = idx_arr_[i];
      }
    }
  }

  __device__ void dumpIdx(idxT* __restrict__ out_idx) const {
    for (int i = 0; i < max_arr_len_; ++i) {
      idxT out_i = i * WARP_SIZE + lane_;
      if (out_i < k_) {
        out_idx[out_i] = idx_arr_[i];
      }
    }
  }

 protected:
  static constexpr int max_arr_len_ = capacity / WARP_SIZE;

  T val_arr_[max_arr_len_];
  idxT idx_arr_[max_arr_len_];

  int const lane_;
  idxT const k_;
  T const dummy_;

};  // end class WarpSort

template <int capacity, bool greater, typename T, typename idxT>
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
 public:
  __device__ WarpSelect(idxT k, T dummy)
      : WarpSort<capacity, greater, T, idxT>(k, dummy),
        k_th_(dummy),
        k_th_lane_((k - 1) % WARP_SIZE) {
    extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];

    int const num_of_warp = blockDim.x / WARP_SIZE;
    int const warp_id = threadIdx.x / WARP_SIZE;
    val_smem_ = reinterpret_cast<T*>(smem_buf);
    val_smem_ += warp_id * WARP_SIZE;
    idx_smem_ = reinterpret_cast<idxT*>(
        smem_buf +
        round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
    idx_smem_ += warp_id * WARP_SIZE;
  }

  __device__ void add(T const* in, idxT start, idxT end) {
    idxT const end_for_fullwarp =
        round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
    for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
      T val = (i < end) ? in[i] : dummy_;
      add(val, i);
    }
  }

  __device__ void add(T val, idxT idx) {
    bool do_add = is_better_than<greater>(val, k_th_);
    uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
    if (mask == 0) {
      return;
    }

    int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
    if (do_add && pos < WARP_SIZE) {
      val_smem_[pos] = val;
      idx_smem_[pos] = idx;
      do_add = false;
    }
    smem_buf_len_ += __popc(mask);
    if (smem_buf_len_ >= WARP_SIZE) {
      __syncwarp();
      merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
      smem_buf_len_ -= WARP_SIZE;
    }
    if (do_add) {
      pos -= WARP_SIZE;
      val_smem_[pos] = val;
      idx_smem_[pos] = idx;
    }
    __syncwarp();
  }

  __device__ void done() {
    if (smem_buf_len_) {
      T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
      idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
      merge_buf_(val, idx);
    }

    // after done(), smem is used for merging results among warps
    __syncthreads();
  }

 private:
  __device__ void set_k_th_() {
    k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
  }

  __device__ void merge_buf_(T val, idxT idx) {
    BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);

    T& old = val_arr_[max_arr_len_ - 1];
    if (is_better_than<greater>(val, old)) {
      old = val;
      idx_arr_[max_arr_len_ - 1] = idx;
    }

    BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);

    set_k_th_();
  }

  using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
  using WarpSort<capacity, greater, T, idxT>::val_arr_;
  using WarpSort<capacity, greater, T, idxT>::idx_arr_;
  using WarpSort<capacity, greater, T, idxT>::lane_;
  using WarpSort<capacity, greater, T, idxT>::k_;
  using WarpSort<capacity, greater, T, idxT>::dummy_;

  T* val_smem_;
  idxT* idx_smem_;
  int smem_buf_len_ = 0;

  T k_th_;
  int const k_th_lane_;
};  // end class WarpSelect
}  // namespace warp_topk
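A hedged sketch of how WarpSelect is intended to be driven, mirroring its use in group_idx_and_topk_idx_kernel further down; the kernel name, the -FLT_MAX sentinel, and the assumption that k <= 32 (capacity = WARP_SIZE) are illustrative choices, not part of this header:

#include <cfloat>
#include "noauxtc_kernel.h"  // assumed include for warp_topk and WARP_SIZE

// Illustrative only: each warp selects the indices of the k largest values of
// its row. Launch with calc_smem_size_for_block_wide<float, int32_t>(warps, k)
// bytes of dynamic shared memory, as invokeNoAuxTc does below.
__global__ void warp_topk_example(const float* rows, int32_t* out_idx,
                                  int row_len, int k) {
  const int warps_per_block = blockDim.x / WARP_SIZE;
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int row = blockIdx.x * warps_per_block + warp_id;

  warp_topk::WarpSelect</*capacity=*/WARP_SIZE, /*greater=*/true, float, int32_t>
      queue(k, -FLT_MAX);

  queue.add(rows + row * row_len, 0, row_len);  // feed candidates, warp stays in lockstep
  queue.done();                                 // flush the shared-memory staging buffer
  queue.dumpIdx(out_idx + row * k);             // write the selected indices
}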

template <typename T>
__device__ void topk_with_k2(T* output,
                             T const* input,
                             cg::thread_block_tile<32> const& tile,
                             int32_t const lane_id,
                             int const num_experts_per_group) {
  // Get the top2 per thread
  T largest = cuda::std::numeric_limits<T>::min();
  T second_largest = cuda::std::numeric_limits<T>::min();

  if (num_experts_per_group > WARP_SIZE) {
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
      T value = input[i];
      if (value > largest) {
        second_largest = largest;
        largest = value;
      } else if (value > second_largest) {
        second_largest = value;
      }
    }
  } else {
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
      largest = input[i];
    }
  }

  __syncwarp();  // Ensure all threads have valid data before reduction
  // Get the top2 warpwise
  T max1 = cg::reduce(tile, largest, cg::greater<T>());

  T max2 = max1;
  bool equal_to_max1 = (max1 == largest);
  int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));

  if (count_max1 == 1) {
    largest = (largest == max1) ? second_largest : largest;
    max2 = cg::reduce(tile, largest, cg::greater<T>());
  }

  if (lane_id == 0) {
    *output = max1 + max2;
  }
}
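The warp routine above amounts to a small scalar computation per group; a hedged CPU equivalent for one group (the sentinel value is an illustrative assumption):

#include <vector>

// Scalar reference for topk_with_k2: the group score is the sum of the two
// largest (bias-adjusted) expert scores in the group.
float group_score_reference(const std::vector<float>& group) {
  float largest = -1e30f, second = -1e30f;  // sketch: assumes real-valued scores
  for (float v : group) {
    if (v > largest) { second = largest; largest = v; }
    else if (v > second) { second = v; }
  }
  return largest + second;
}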

template <typename T>
__global__ void topk_with_k2_kernel(T* output,
                                    T* input,
                                    int64_t const num_tokens,
                                    int64_t const num_cases,
                                    int64_t const n_group,
                                    int64_t const num_experts_per_group) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;

  int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
  if (case_id < num_cases) {
    input += case_id * num_experts_per_group;
    output += case_id;

    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

    topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
  }
}

template <typename T>
__global__ void group_idx_and_topk_idx_kernel(
    T* scores,
    T const* group_scores,
    T* scores_with_bias,
    int64_t const num_tokens,
    int64_t const n_group,
    int64_t const topk_group,
    int64_t const topk,
    int64_t const num_experts,
    int64_t const num_experts_per_group,
    double routed_scaling_factor) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;
  int32_t case_id =
      blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
  scores_with_bias += case_id * num_experts;
  scores += case_id * num_experts;
  group_scores += case_id * n_group;
  int32_t align_num_experts_per_group =
      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                      // store the target topk idx
  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
  T* s_topk_value =
      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
      warp_id * topk;

  T value = cuda::std::numeric_limits<T>::min();
  T topk_group_value = cuda::std::numeric_limits<T>::min();
  int32_t num_equalto_topkth_group;

  if ((n_group > topk_group) && (case_id < num_tokens)) {
    // calculate group_idx
    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
    if (lane_id < n_group) {
      value = group_scores[lane_id];
    }

    int count_equal_to_top_value = WARP_SIZE - n_group;
    int pre_count_equal_to_top_value = 0;
    // Use a loop to find the largest topk_group group values
    while (count_equal_to_top_value < target_num_min) {
      __syncwarp();  // Ensure all threads have valid data before reduction
      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
      if (value == topk_group_value) {
        value = cuda::std::numeric_limits<T>::min();
      }
      pre_count_equal_to_top_value = count_equal_to_top_value;
      count_equal_to_top_value = __popc(__ballot_sync(
          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
    }
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
  }
  __syncthreads();

  warp_topk::WarpSelect</*capacity*/ WARP_SIZE, /*greater*/ true, T, int32_t>
      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());

  int count_equalto_topkth_group = 0;
  if (case_id < num_tokens) {
    for (int i_group = 0; i_group < n_group; i_group++) {
      if ((group_scores[i_group] > topk_group_value) ||
          ((group_scores[i_group] == topk_group_value) &&
           (count_equalto_topkth_group < num_equalto_topkth_group))) {
        int32_t offset = i_group * num_experts_per_group;
        for (int32_t i = lane_id; i < align_num_experts_per_group;
             i += WARP_SIZE) {
          T candidates = i < num_experts_per_group
                             ? scores_with_bias[offset + i]
                             : cuda::std::numeric_limits<T>::min();
          queue.add(candidates, offset + i);
        }
        if (group_scores[i_group] == topk_group_value) {
          count_equalto_topkth_group++;
        }
      }
    }
    queue.done();
    __syncwarp();
    // Get the topk_idx
    queue.dumpIdx(s_topk_idx);
    __syncwarp();
  }

  // Load the valid score value
  // Calculate the summation
  float topk_sum = 1e-20;
  if (case_id < num_tokens) {
    for (int i = lane_id;
         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
         i += WARP_SIZE) {
      T value = i < topk ? scores[s_topk_idx[i]]
                         : 0.0f;  // Load the valid value of expert
      if (i < topk) {
        s_topk_value[i] = value;
      }
      topk_sum += cg::reduce(tile, value, cg::plus<float>());
    }
  }

  __syncthreads();
  if (case_id < num_tokens) {
    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
      scores[i] = 0;
    }
  }
  __threadfence();
  __syncthreads();

  if (case_id < num_tokens) {
    for (int i = lane_id; i < topk; i += WARP_SIZE) {
      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
      scores[s_topk_idx[i]] = value;
    }
  }
}

template <typename T>
void invokeNoAuxTc(T* scores,
                   T* group_scores,
                   T* scores_with_bias,
                   int64_t const num_tokens,
                   int64_t const num_experts,
                   int64_t const n_group,
                   int64_t const topk_group,
                   int64_t const topk,
                   double const routed_scaling_factor,
                   cudaStream_t const stream) {
  int64_t num_cases = num_tokens * n_group;
  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
  topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
      group_scores, scores_with_bias, num_tokens, num_cases, n_group,
      num_experts / n_group);

  int64_t topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
  size_t dynamic_smem_in_bytes =
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);

  group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
                                     BLOCK_SIZE,
                                     dynamic_smem_in_bytes,
                                     stream>>>(scores,
                                               group_scores,
                                               scores_with_bias,
                                               num_tokens,
                                               n_group,
                                               topk_group,
                                               topk,
                                               num_experts,
                                               num_experts / n_group,
                                               routed_scaling_factor);
}

#define INSTANTIATE_NOAUX_TC(T)                                      \
  template void invokeNoAuxTc<T>(T * scores,                         \
                                 T * group_scores,                   \
                                 T * scores_with_bias,               \
                                 int64_t const num_tokens,           \
                                 int64_t const num_experts,          \
                                 int64_t const n_group,              \
                                 int64_t const topk_group,           \
                                 int64_t const topk,                 \
                                 double const routed_scaling_factor, \
                                 cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float);
@@ -50,11 +50,13 @@ __global__ void quant_per_token_per_block(const T *input,
    max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
  }
  // get max value per warp
  max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread);
  max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread);
  max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread);
  max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread);
  max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread);
  max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread);
  max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread);
  max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread);
  max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread);
  max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread);
  // broadcast max_value
  max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
  max_value_thread = max(max_value_thread, epsilon);
  float scale_to_store = max_value_thread / MAX_VALUE;
  // quant
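In this hunk the xor "butterfly" reduction of the old lines leaves the maximum in every lane, whereas the __shfl_down_sync tree of the new lines leaves it only in lane 0, which is why the explicit broadcast is added. A hedged standalone sketch of the new pattern (illustrative function name, not the kernel itself):

// Warp-wide max via a shfl_down tree, then broadcast lane 0's result so all
// 32 lanes observe the same value.
__device__ float warp_max_then_broadcast(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
  }
  return __shfl_sync(0xffffffff, v, 0);
}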

@@ -18,6 +18,7 @@

std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
                                               const paddle::Tensor &top_p,
                                               const paddle::optional<paddle::Tensor> &top_k,
                                               int seed) {
  std::vector<int64_t> probs_shape = probs.shape();
  unsigned int batch_size = probs_shape[0];
@@ -40,10 +41,18 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,

  cudaError_t status;

  status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
      const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
      batch_size, top_p.data<float>(), vocab_size,
      true, philox_seed, philox_offset, cu_stream);
  if (top_k) {
    status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
        const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
        batch_size, top_p.data<float>(), top_k.get().data<int64_t>(), vocab_size,
        true, philox_seed, philox_offset, cu_stream);
  }
  else {
    status = sampling::TopPSamplingFromProb<float, int64_t>(
        const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
        batch_size, top_p.data<float>(), vocab_size,
        true, philox_seed, philox_offset, cu_stream);
  }

  PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                      std::string(cudaGetErrorString(status)));
@@ -53,19 +62,21 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,

std::vector<std::vector<int64_t>>
TopPSamplingRejectInferShape(const std::vector<int64_t> &probs_shape,
                             const std::vector<int64_t> &top_p_shape) {
                             const std::vector<int64_t> &top_p_shape,
                             const paddle::optional<std::vector<int64_t>> &top_k_shape) {
  int64_t bs = probs_shape[0];
  return {{bs, 1}};
}

std::vector<paddle::DataType>
TopPSamplingRejectInferDtype(const paddle::DataType &probs_dtype,
                             const paddle::DataType &top_p_shape) {
                             const paddle::DataType &top_p_dtype,
                             const paddle::optional<paddle::DataType> &top_k_dtype) {
  return {paddle::DataType::INT64};
}

PD_BUILD_STATIC_OP(rejection_top_p_sampling)
    .Inputs({"probs", "top_p"})
    .Inputs({"probs", "top_p", paddle::Optional("top_k")})
    .Outputs({"samples"})
    .Attrs({"seed: int"})
    .SetKernelFn(PD_KERNEL(TopPSamplingReject))
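Semantically, the new optional top_k input restricts sampling to the k most probable tokens before the top-p (nucleus) threshold is applied; the kernels below realize this with rejection sampling rather than sorting, but a hedged CPU sketch of the equivalent filtering (helper name and renormalization convention are assumptions) is:

#include <algorithm>
#include <numeric>
#include <vector>

// Keep the k most probable tokens, then the smallest prefix of them whose
// cumulative mass reaches top_p, and renormalize the survivors.
std::vector<float> topk_topp_filter(std::vector<float> probs, int64_t k, float top_p) {
  std::vector<int> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(), [&](int a, int b) { return probs[a] > probs[b]; });

  std::vector<float> out(probs.size(), 0.0f);
  float mass = 0.0f;
  for (size_t r = 0; r < idx.size() && r < static_cast<size_t>(k); ++r) {
    out[idx[r]] = probs[idx[r]];
    mass += probs[idx[r]];
    if (mass >= top_p) break;     // nucleus cut-off
  }
  for (float& p : out) p /= mass;  // renormalize the surviving tokens
  return out;
}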

@@ -279,7 +279,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
          typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
                                               float* top_p_arr, IdType* top_k_arr,
                                               uint32_t d, uint64_t philox_seed,
                                               uint64_t philox_offset) {
  const uint32_t batch_size = gridDim.x;
@@ -287,7 +288,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, flo
  curandStatePhilox4_32_10_t state;
  curand_init(philox_seed, bx, philox_offset, &state);
  const uint32_t row_idx = bx;
  const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
  const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];

  extern __shared__ __align__(
@@ -479,7 +480,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
      if (aggregate_gt_pivot_0 < top_p) {
        // case 1: pivot_0 accepted
        break;
      }
    }
    if (aggregate_gt_pivot_1 < top_p) {
      // case 2: pivot_0 rejected, pivot_1 accepted
      low = pivot_0;
@@ -497,6 +498,183 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
  }
}

template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
          typename TempStorage>
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
                                             TempStorage& temp_storage) {
  const uint32_t tx = threadIdx.x;
  vec_t<float, VEC_SIZE> in_data_vec;

  float max_val = 0;
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    in_data_vec.fill(0);
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
    }
    float in_data_[VEC_SIZE];
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      in_data_[j] = in_data_vec[j];
    }
    max_val = max(
        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .Reduce<VEC_SIZE>(in_data_, cub::Max()));
    __syncthreads();
  }
  if (tx == 0) {
    temp_storage.max_val = max_val;
  }
  __syncthreads();
  return temp_storage.max_val;
}

template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
struct RenormTempStorage {
  union {
    typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
    typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
    typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
        reduce_value_count;
  } block_prim;
  struct {
    float max_val;
    float min_val;
    union {
      struct {
        float values[2];
      };
      struct {
        int counts[2];
      };
      struct {
        ValueCount<float> pairs[2];
      };
    } block_aggregate;
  };
};

template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
          typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  const uint32_t row_idx = bx;
  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
  double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
  vec_t<float, VEC_SIZE> probs_vec;
  if (k < d) {
    extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
        uint8_t smem_renorm[];
    auto& temp_storage =
        reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
    temp_storage.max_val = 0;

    float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
                                RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
        probs, row_idx, d, temp_storage);

    double low = 0, high = max_val;
    float min_gt_low, max_le_high;
    float sum_low = 1;
    // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
    // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
    // loop invariant:
    // - f(low) >= k, f(high) < k
    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
    // stopping condition: min_gt_low == max_le_high
    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
    do {
      double pivot_0 = (high + 2 * low) / 3;
      double pivot_1 = (2 * high + low) / 3;

      ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
      min_gt_low = high;
      max_le_high = low;
#pragma unroll 2
      for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
        probs_vec.fill(0);
        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
          probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
        }
        ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
#pragma unroll
        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
          probs_gt_pivot_0_pair[j] = {
              (probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
              (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
          probs_gt_pivot_1_pair[j] = {
              (probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
              (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};

          if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
            min_gt_low = min(min_gt_low, probs_vec[j]);
          }
          if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
            max_le_high = max(max_le_high, probs_vec[j]);
          }
        }

        aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                    temp_storage.block_prim.reduce_value_count)
                                    .Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
        __syncthreads();

        aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                    temp_storage.block_prim.reduce_value_count)
                                    .Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
        __syncthreads();
      }
      min_gt_low =
          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
              .Reduce(min_gt_low, cub::Min());
      __syncthreads();
      max_le_high =
          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
              .Reduce(max_le_high, cub::Max());
      if (tx == 0) {
        temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
        temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
        temp_storage.min_val = min_gt_low;
        temp_storage.max_val = max_le_high;
      }
      __syncthreads();
      aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
      aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
      min_gt_low = temp_storage.min_val;
      max_le_high = temp_storage.max_val;

      if (aggregate_gt_pivot_1.count >= k) {
        low = pivot_1;
        sum_low = float(aggregate_gt_pivot_1.value);
      } else if (aggregate_gt_pivot_0.count >= k) {
        low = pivot_0;
        high = min(pivot_1, max_le_high);
        sum_low = float(aggregate_gt_pivot_0.value);
      } else {
        high = min(pivot_0, max_le_high);
      }
    } while (min_gt_low != max_le_high);

    normalizer = ptx_rcp(max(sum_low, 1e-8));
    pivot = low;
  }

  // normalize
#pragma unroll 2
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    probs_vec.fill(0);
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
    }
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
    }
  }
}

template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
                                 uint32_t batch_size, const T *top_p_val,
@@ -529,7 +707,7 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,

template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
                                     uint32_t batch_size, const T *top_p_val,
                                     uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
                                     uint32_t d, bool deterministic,
                                     uint64_t philox_seed, uint64_t philox_offset,
                                     cudaStream_t stream = 0) {
@@ -540,7 +718,7 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  void* args[] = {&probs, &output, &top_p_val,
  void* args[] = {&probs, &output, &top_p_val, &top_k_val,
                  &d, &philox_seed, &philox_offset};

  DISPATCH_ALIGNED_VEC_SIZE(
@@ -556,4 +734,26 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
  });
}

}  // namespace sampling
template <typename DType, typename IdType>
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
                           uint32_t batch_size, uint32_t d,
                           cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

  auto compute_capacity = GetCudaComputeCapability();
  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
    const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
    dim3 nblks(batch_size);
    dim3 nthrs(BLOCK_THREADS);
    void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
    DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
      auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
      CUDA_CALL(
          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
      CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
    });
    return cudaSuccess;
  });
}

}  // namespace sampling

61  custom_ops/gpu_ops/sample_kernels/top_k_renorm_probs.cu  Normal file
@@ -0,0 +1,61 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"

std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
                                       const paddle::Tensor &top_k) {
  std::vector<int64_t> probs_shape = probs.shape();
  uint32_t batch_size = probs_shape[0];
  uint32_t vocab_size = probs_shape[1];
  auto cu_stream = probs.stream();

  auto renorm_probs =
      GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());

  cudaError_t status;

  status = sampling::TopKRenormProb<float>(
      const_cast<float *>(probs.data<float>()),
      renorm_probs.data<float>(),
      const_cast<int64_t *>(top_k.data<int64_t>()),
      batch_size, vocab_size, cu_stream);

  PD_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " +
                                      std::string(cudaGetErrorString(status)));

  return {renorm_probs};
}

std::vector<std::vector<int64_t>>
TopKRenormInferShape(const std::vector<int64_t> &probs_shape,
                     const std::vector<int64_t> &top_k_shape) {
  return {probs_shape};
}

std::vector<paddle::DataType>
TopKRenormInferDtype(const paddle::DataType &probs_dtype,
                     const paddle::DataType &top_k_shape) {
  return {probs_dtype};
}

PD_BUILD_STATIC_OP(top_k_renorm_probs)
    .Inputs({"probs", "top_k"})
    .Outputs({"renorm_probs"})
    .SetKernelFn(PD_KERNEL(TopKRenorm))
    .SetInferShapeFn(PD_INFER_SHAPE(TopKRenormInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(TopKRenormInferDtype));
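For orientation, the top_k_renorm_probs op built on TopKRenormProbKernel keeps, for each row, only the k largest probabilities and rescales them to sum to one (k == 0 meaning keep everything); a hedged CPU sketch of that contract, in which ties at the threshold are all kept:

#include <algorithm>
#include <functional>
#include <vector>

// CPU reference for one row: zero everything below the k-th largest
// probability, then renormalize the survivors.
std::vector<float> topk_renorm_reference(std::vector<float> probs, int64_t k) {
  if (k == 0 || static_cast<size_t>(k) >= probs.size()) return probs;
  std::vector<float> sorted = probs;
  std::nth_element(sorted.begin(), sorted.begin() + (k - 1), sorted.end(),
                   std::greater<float>());
  const float pivot = sorted[k - 1];   // k-th largest value
  float mass = 0.0f;
  for (float p : probs) if (p >= pivot) mass += p;
  for (float &p : probs) p = (p >= pivot) ? p / mass : 0.0f;
  return probs;
}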
|
@@ -23,34 +23,34 @@
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
#define MAX_BSZ 128
|
||||
#define K 10
|
||||
#define MAX_BSZ 512
|
||||
#define K 20
|
||||
// #define SAVE_WITH_OUTPUT_DEBUG
|
||||
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
|
||||
float mtext_f[MAX_BSZ * (K + 1)]; // score
|
||||
int mtext_ranks[MAX_BSZ]; // ranks
|
||||
};
|
||||
|
||||
void SaveOutMmsgTopK(const paddle::Tensor& x,
|
||||
const paddle::Tensor& scores,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_scores, // [bsz, k]
|
||||
const paddle::Tensor& logprob_token_ids, // [bsz, k+1]
|
||||
const paddle::Tensor& logprob_scores, // [bsz, k+1]
|
||||
const paddle::Tensor& ranks,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int k,
|
||||
int64_t rank_id) {
|
||||
if (rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
|
||||
auto scores_cpu = scores.copy_to(paddle::CPUPlace(), false);
|
||||
auto topk_ids_cpu = topk_ids.copy_to(paddle::CPUPlace(), false);
|
||||
auto topk_scores_cpu = topk_scores.copy_to(paddle::CPUPlace(), false);
|
||||
auto logprob_token_ids_cpu = logprob_token_ids.copy_to(paddle::CPUPlace(), false);
|
||||
auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
|
||||
auto ranks_cpu = ranks.copy_to(paddle::CPUPlace(), false);
|
||||
int64_t* x_data = x_cpu.data<int64_t>();
|
||||
float* scores_data = scores_cpu.data<float>();
|
||||
int64_t* topk_ids_data = topk_ids_cpu.data<int64_t>();
|
||||
float* topk_scores_data = topk_scores_cpu.data<float>();
|
||||
int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
|
||||
float* logprob_scores_data = logprob_scores_cpu.data<float>();
|
||||
int64_t* ranks_data = ranks_cpu.data<int64_t>();
|
||||
static struct msgdata msg_sed;
|
||||
int msg_queue_id = 1;
|
||||
if (const char* inference_msg_queue_id_env_p =
|
||||
@@ -106,21 +106,23 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
|
||||
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
|
||||
: -inference_msg_id_from_env;
|
||||
int bsz = x.shape()[0];
|
||||
int max_num_logprobs = logprob_token_ids.shape()[1];
|
||||
msg_sed.mtext[1] = bsz;
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
for (int j = 0; j < k + 1; j++) {
|
||||
for (int j = 0; j < K + 1; j++) {
|
||||
const int64_t offset = i * (K + 1) + j;
|
||||
if (j == 0) {
|
||||
msg_sed.mtext[offset + 2] = (int)x_data[i];
|
||||
msg_sed.mtext_f[offset] = scores_data[i];
|
||||
} else if (j <= k + 1) {
|
||||
msg_sed.mtext[offset + 2] = (int)topk_ids_data[i * k + j - 1];
|
||||
msg_sed.mtext_f[offset] = topk_scores_data[i * k + j - 1];
|
||||
msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
|
||||
} else if (j < max_num_logprobs) {
|
||||
msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[i * max_num_logprobs + j];
|
||||
msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
|
||||
} else {
|
||||
msg_sed.mtext[offset + 2] = -1;
|
||||
msg_sed.mtext_f[offset] = 0.0;
|
||||
}
|
||||
}
|
||||
msg_sed.mtext_ranks[i] = (int)ranks_data[i];
|
||||
}
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
std::cout << "msg data: ";
|
||||
@@ -131,7 +133,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
|
||||
#endif
|
||||
if ((msgsnd(msgid,
|
||||
&msg_sed,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4 + MAX_BSZ * 4,
|
||||
0)) == -1) {
|
||||
printf("full msg buffer\n");
|
||||
}
|
||||
@@ -139,8 +141,8 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(save_output_topk)
|
||||
.Inputs({"x", "scores", "topk_ids", "topk_scores", "not_need_stop"})
|
||||
.Attrs({"k: int", "rank_id: int64_t"})
|
||||
.Inputs({"x", "topk_ids", "logprob_scores", "ranks", "not_need_stop"})
|
||||
.Attrs({"rank_id: int64_t"})
|
||||
.Outputs({"x_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SaveOutMmsgTopK));
|
||||
|
@@ -267,6 +267,10 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/text_image_index_out.cu",
|
||||
"gpu_ops/text_image_gather_scatter.cu",
|
||||
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
|
||||
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
|
||||
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
|
||||
"gpu_ops/fused_rotary_position_encoding.cu",
|
||||
"gpu_ops/noaux_tc.cu",
|
||||
]
|
||||
|
||||
# pd_disaggregation
|
||||
@@ -376,6 +380,8 @@ elif paddle.is_compiled_with_cuda():
|
||||
# append_attention
|
||||
sources += ["gpu_ops/append_attention.cu"]
|
||||
sources += find_end_files("gpu_ops/append_attn", ".cu")
|
||||
# mla
|
||||
sources += ["gpu_ops/multi_head_latent_attention.cu"]
|
||||
# gemm_dequant
|
||||
sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"]
|
||||
# speculate_decoding
|
||||
@@ -441,6 +447,10 @@ elif paddle.is_compiled_with_cuda():
|
||||
|
||||
sources += find_end_files(fp8_auto_gen_directory, ".cu")
|
||||
|
||||
if cc >= 90 and nvcc_version >= 12.0:
|
||||
# Hopper optimized mla
|
||||
sources += find_end_files("gpu_ops/mla_attn", ".cu")
|
||||
|
||||
setup(
|
||||
name="fastdeploy_ops",
|
||||
ext_modules=CUDAExtension(
|
||||
|
@@ -22,6 +22,7 @@ setup(
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
"gpu_ops/get_output.cc",
|
||||
"gpu_ops/get_output_msg_with_topk.cc",
|
||||
"gpu_ops/save_output_msg_with_topk.cc",
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"cpu_ops/rebuild_padding.cc",
|
||||
],
|
||||
|
@@ -9,11 +9,6 @@ COPY . /workspace/FastDeploy
|
||||
RUN echo "ulimit -u unlimited" >> /root/.bashrc
|
||||
RUN echo "ulimit -n 65536" >> /root/.bashrc
|
||||
|
||||
# setting proxy
|
||||
ARG http_proxy=agent.baidu.com:8891
|
||||
ARG https_proxy=agent.baidu.com:8891
|
||||
ARG no_proxy=localhost,bj.bcebos.com,su.bcebos.com,pypi.tuna.tsinghua.edu.cn,paddle-ci.gz.bcebos.com
|
||||
|
||||
# uninstall existing package
|
||||
RUN python -m pip uninstall paddlepaddle-gpu fastdeploy-gpu -y
|
||||
|
||||
|
@@ -7,11 +7,6 @@ deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy main restricted universe
|
||||
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-updates main restricted universe multiverse \n\
|
||||
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-backports main restricted universe multiverse" > /etc/apt/sources.list
|
||||
|
||||
# setting proxy
|
||||
ENV http_proxy=http://agent.baidu.com:8891
|
||||
ENV https_proxy=http://agent.baidu.com:8891
|
||||
ENV no_proxy=localhost,bj.bcebos.com,su.bcebos.com,pypi.tuna.tsinghua.edu.cn,paddle-ci.gz.bcebos.com
|
||||
|
||||
RUN apt-get update && apt-get install -y libibverbs-dev librdmacm-dev cmake pybind11-dev
|
||||
|
||||
# uninstall existing package
|
||||
@@ -40,4 +35,4 @@ COPY . /workspace/FastDeploy
|
||||
RUN cd /workspace/FastDeploy && bash build.sh && python -m pip install --no-cache-dir dist/* && rm -rf /workspace/FastDeploy
|
||||
|
||||
ENV http_proxy=""
|
||||
ENV https_proxy=""
|
||||
ENV https_proxy=""
|
||||
|
@@ -1,19 +1,19 @@
|
||||
# Chain-of-Thought Content
|
||||
# Reasoning Outputs
|
||||
|
||||
The reasoning model returns a `reasoning_content` field in the output, representing the chain-of-thought content—the reasoning steps that lead to the final conclusion.
|
||||
Reasoning models return an additional `reasoning_content` field in their output, which contains the reasoning steps that led to the final conclusion.
|
||||
|
||||
## Currently Supported Chain-of-Thought Models
|
||||
| Model Name | Parser Name | Chain-of-Thought Enabled by Default |
|
||||
|----------------|----------------|-------------------------------------|
|
||||
| ernie-45-vl | ernie-45-vl | ✓ |
|
||||
| ernie-lite-vl | ernie-45-vl | ✓ |
|
||||
## Supported Models
|
||||
| Model Name | Parser Name | `enable_thinking` by Default |
|
||||
|----------------|----------------|---------------------------|
|
||||
| baidu/ERNIE-4.5-VL-424B-A47B-Paddle | ernie-45-vl | ✓ |
|
||||
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | ernie-45-vl | ✓ |
|
||||
|
||||
The reasoning model requires a specified parser to interpret the reasoning content. The reasoning mode can be disabled by setting the `enable_thinking=False` parameter.
|
||||
The reasoning model requires a specified parser to extract reasoning content. The reasoning mode can be disabled by setting the `enable_thinking=False` parameter.
|
||||
|
||||
Interfaces that support toggling the reasoning mode:
|
||||
1. `/v1/chat/completions` request in OpenAI services.
|
||||
2. `/v1/chat/completions` request in the OpenAI Python client.
|
||||
3. `llm.chat` request in Offline interfaces.
|
||||
1. `/v1/chat/completions` requests in OpenAI services.
|
||||
2. `/v1/chat/completions` requests in the OpenAI Python client.
|
||||
3. `llm.chat` requests in Offline interfaces.
|
||||
|
||||
For reasoning models, the length of the reasoning content can be controlled via `reasoning_max_tokens`. Add `metadata={"reasoning_max_tokens": 1024}` to the request.
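A minimal sketch of setting this field from the OpenAI Python client (assuming the server launched below is reachable on port 8192; the non-standard `metadata` field is forwarded through `extra_body`):
```python
# Illustrative sketch: cap the reasoning content at 1024 tokens.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:8192/v1", api_key="null")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "How many r's are there in the word 'strawberry'?"}],
    extra_body={"metadata": {"reasoning_max_tokens": 1024, "enable_thinking": True}},
)
print(response.choices[0].message.content)
```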
|
||||
|
||||
@@ -21,10 +21,15 @@ For reasoning models, the length of the reasoning content can be controlled via
|
||||
When launching the model service, specify the parser name using the `--reasoning-parser` argument.
|
||||
This parser will process the model's output and extract the `reasoning_content` field.
|
||||
```bash
|
||||
python -m fastdeploy.entrypoints.openai.api_server --model /root/merge_llm_model --enable-mm --tensor-parallel-size=8 --port 8192 --quantization wint4 --reasoning-parser=ernie-45-vl
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model /path/to/your/model \
|
||||
--enable-mm \
|
||||
--tensor-parallel-size 8 \
|
||||
--port 8192 \
|
||||
--quantization wint4 \
|
||||
--reasoning-parser ernie-45-vl
|
||||
```
|
||||
|
||||
Next, send a `chat completion` request to the model:
|
||||
Next, make a request to the model that should return the reasoning content in the response.
|
||||
```bash
|
||||
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
@@ -40,8 +45,8 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
|
||||
```
|
||||
The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself.
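As a rough sketch (reusing the endpoint from the curl request above), both fields can be read from the JSON response like this:
```python
# Illustrative sketch: separate the reasoning steps from the final answer.
import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Which number is larger, 9.11 or 9.9?"}],
    "metadata": {"enable_thinking": True},
}
resp = requests.post("http://0.0.0.0:8192/v1/chat/completions", json=payload).json()

message = resp["choices"][0]["message"]
print("reasoning:", message.get("reasoning_content"))
print("answer:", message.get("content"))
```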
|
||||
|
||||
### Streaming Sessions
|
||||
In streaming sessions, the `reasoning_content` field can be retrieved from the `delta` in `chat completion response chunks`.
|
||||
### Streaming chat completions
|
||||
Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in `chat completion response chunks`
|
||||
```python
|
||||
from openai import OpenAI
|
||||
# Set OpenAI's API key and API base to use vLLM's API server.
|
||||
|
@@ -29,6 +29,7 @@ for output in outputs:
|
||||
```
|
||||
|
||||
### Chat Interface (LLM.chat)
|
||||
|
||||
```python
|
||||
from fastdeploy import LLM, SamplingParams
|
||||
|
||||
@@ -99,6 +100,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
|
||||
* repetition_penalty(float): Direct penalty for repeated tokens (>1 penalizes, <1 encourages)
|
||||
* temperature(float): Controls randomness (higher = more random)
|
||||
* top_p(float): Probability threshold for token selection
|
||||
* top_k(int): Number of tokens considered for sampling
|
||||
* max_tokens(int): Maximum generated tokens (input + output)
|
||||
* min_tokens(int): Minimum forced generation length
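A brief illustrative sketch of these parameters in offline inference (values are arbitrary; passing `sampling_params` to `generate` is an assumption based on the offline examples in this page):
```python
# Illustrative sketch only: sampling parameters for offline generation.
from fastdeploy import LLM, SamplingParams

sampling_params = SamplingParams(
    temperature=0.8,          # higher values increase randomness
    top_p=0.95,               # nucleus sampling threshold
    repetition_penalty=1.05,  # >1 penalizes repeated tokens
    max_tokens=256,
    min_tokens=1,
)
llm = LLM(model="baidu/ERNIE-4.5-21B-A3B-Paddle", tensor_parallel_size=1)
outputs = llm.generate(prompts=["Introduce yourself."], sampling_params=sampling_params, use_tqdm=True)
```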
|
||||
|
||||
@@ -129,4 +131,4 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
|
||||
* first_token_time(float): First token latency
|
||||
* time_in_queue(float): Queuing time
|
||||
* model_forward_time(float): Forward pass duration
|
||||
* model_execute_time(float): Total execution time (including preprocessing)
|
||||
* model_execute_time(float): Total execution time (including preprocessing)
|
||||
|
@@ -62,7 +62,7 @@ The differences in request parameters between FastDeploy and the OpenAI protocol
|
||||
- `stream_options`: Optional[StreamOptions] = None
|
||||
- `temperature`: Optional[float] = None
|
||||
- `top_p`: Optional[float] = None
|
||||
- `metadata`: Optional[dict] = None (supported only in `v1/chat/completions` for configuring additional parameters, e.g., `meta_data={"enable_thinking": True}`)
|
||||
- `metadata`: Optional[dict] = None (supported only in `v1/chat/completions` for configuring additional parameters, e.g., `metadata={"enable_thinking": True}`)
|
||||
- `min_tokens`: Optional[int] = 1 (minimum number of tokens generated)
|
||||
- `reasoning_max_tokens`: Optional[int] = None (maximum number of tokens for reasoning content, defaults to the same as `max_tokens`)
|
||||
- `enable_thinking`: Optional[bool] = True (whether to enable reasoning for models that support deep thinking)
|
||||
|
@@ -27,7 +27,7 @@ When using FastDeploy to deploy models (including offline inference and service
|
||||
| ```kv_cache_ratio``` | `float` | KVCache blocks are divided between Prefill phase and Decode phase according to kv_cache_ratio ratio, default: 0.75 |
|
||||
| ```enable_prefix_caching``` | `bool` | Whether to enable Prefix Caching, default: False |
|
||||
| ```swap_space``` | `float` | When Prefix Caching is enabled, CPU memory size for KVCache swapping, unit: GB, default: None |
|
||||
| ```enable_chunk_prefill``` | `bool` | Enable Chunked Prefill, default: False |
|
||||
| ```enable_chunked_prefill``` | `bool` | Enable Chunked Prefill, default: False |
|
||||
| ```max_num_partial_prefills``` | `int` | When Chunked Prefill is enabled, maximum concurrent number of partial prefill batches, default: 1 |
|
||||
| ```max_long_partial_prefills``` | `int` | When Chunked Prefill is enabled, maximum number of long requests in concurrent partial prefill batches, default: 1 |
|
||||
| ```long_prefill_token_threshold``` | `int` | When Chunked Prefill is enabled, requests with token count exceeding this value are considered long requests, default: max_model_len*0.04 |
|
||||
@@ -115,5 +115,5 @@ FastDeploy initialization sequence first uses `gpu_memory_utilization` parameter
|
||||
...
|
||||
```
|
||||
- When ```use_cudagraph``` is enabled, currently only supports single-GPU inference, i.e. ```tensor_parallel_size``` set to 1.
|
||||
- When ```use_cudagraph``` is enabled, cannot enable ```enable_prefix_caching``` or ```enable_chunk_prefill```.
|
||||
- When ```use_cudagraph``` is enabled, cannot enable ```enable_prefix_caching``` or ```enable_chunked_prefill```.
|
||||
- When ```use_cudagraph``` is enabled, batches with size ≤ ```max_capture_batch_size``` will be executed by CudaGraph, batches > ```max_capture_batch_size``` will be executed by original dynamic/static graph. To have all batch sizes executed by CudaGraph, ```max_capture_batch_size``` value should match ```max_num_seqs```. ```max_capture_batch_size``` > ```max_num_seqs``` will cause waste by capturing batches that won't be encountered during inference, occupying more time and memory.
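For example, the relationship can be expressed through the engine arguments (an illustrative sketch; `max_num_seqs` as a keyword here is an assumption and the exact constructor signature may differ):
```python
# Illustrative sketch: align max_capture_batch_size with max_num_seqs so every
# batch size is executed by CUDA Graph, per the note above.
from fastdeploy.engine.args_utils import EngineArgs

engine_args = EngineArgs(
    model="baidu/ERNIE-4.5-21B-A3B-Paddle",
    tensor_parallel_size=1,        # CUDA Graph currently supports single-GPU inference only
    use_cudagraph=True,
    max_capture_batch_size=64,
    max_num_seqs=64,
)
```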
|
@@ -57,3 +57,6 @@ On the ERNIE-4.5-300B-A47B model, comparison of WINT2 vs WINT4 performance:
|
||||
| IFEval |500|88.17 | 85.40 |
|
||||
|BBH|6511|94.43|92.02|
|
||||
|DROP|9536|91.17|89.97|
|
||||
|GSM8K|1319|96.21|95.98|
|
||||
|CMath|600|96.50|96.00|
|
||||
|CMMLU|11477|89.92|86.22|
|
@@ -52,7 +52,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_ATTENTION_BACKEND":
|
||||
lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
|
||||
|
||||
# Sampling class ("base", "air", or "rejection")
|
||||
# Sampling class ("base", "base_non_truncated", "air", or "rejection")
|
||||
"FD_SAMPLING_CLASS":
|
||||
lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
|
||||
|
||||
@@ -67,6 +67,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Switch from standalone PD to centralized inference (0 or 1)
|
||||
"FD_PD_CHANGEABLE":
|
||||
lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
|
||||
|
||||
|
||||
}
|
||||
```
|
||||
```
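As a small illustrative sketch, these variables can be overridden from Python before FastDeploy is imported (values are the ones listed above; setting them early in the process is the safest assumption, since they are read through the lambdas at access time):
```python
# Illustrative sketch: pick the rejection sampling class and keep the default attention backend.
import os

os.environ["FD_SAMPLING_CLASS"] = "rejection"        # "base", "base_non_truncated", "air", or "rejection"
os.environ["FD_ATTENTION_BACKEND"] = "APPEND_ATTN"   # default backend

from fastdeploy import LLM  # import after the variables are set
```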
|
||||
|
@@ -5,8 +5,8 @@
|
||||
##目前支持思考链的模型
|
||||
| 模型名称 | 解析器名称 | 默认开启思考链 |
|
||||
|---------------|-------------|---------|
|
||||
| ernie-45-vl | ernie-45-vl | ✓ |
|
||||
| ernie-lite-vl | ernie-45-vl | ✓ |
|
||||
| baidu/ERNIE-4.5-VL-424B-A47B-Paddle | ernie-45-vl | ✓ |
|
||||
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | ernie-45-vl | ✓ |
|
||||
|
||||
思考模型需要指定解析器,以便于对思考内容进行解析. 通过`enable_thinking=False` 参数可以关闭模型思考模式.
|
||||
|
||||
|
@@ -100,6 +100,7 @@ for output in outputs:
|
||||
* repetition_penalty(float): 直接对重复生成的token进行惩罚的系数(>1时惩罚重复,<1时鼓励重复)
|
||||
* temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定
|
||||
* top_p(float): 概率累积分布截断阈值,仅考虑累计概率达到此阈值的最可能token集合
|
||||
* top_k(int): 采样概率最高的token数量,考虑概率最高的k个token进行采样
|
||||
* max_tokens(int): 限制模型生成的最大token数量(包括输入和输出)
|
||||
* min_tokens(int): 强制模型生成的最少token数量,避免过早结束
|
||||
|
||||
|
@@ -61,7 +61,7 @@ FastDeploy 与 OpenAI 协议的请求参数差异如下,其余请求参数会
|
||||
- `stream_options`: Optional[StreamOptions] = None
|
||||
- `temperature`: Optional[float] = None
|
||||
- `top_p`: Optional[float] = None
|
||||
- `metadata`: Optional[dict] = None (仅在v1/chat/compeltions中支持,用于配置额外参数, 如meta_data={"enable_thinking": True})
|
||||
- `metadata`: Optional[dict] = None (仅在v1/chat/completions中支持,用于配置额外参数, 如metadata={"enable_thinking": True})
|
||||
- `min_tokens`: Optional[int] = 1 最小生成的Token个数
|
||||
- `reasoning_max_tokens`: Optional[int] = None 思考内容最大Token数,默认与max_tokens一致
|
||||
- `enable_thinking`: Optional[bool] = True 支持深度思考的模型是否打开思考
|
||||
|
@@ -26,7 +26,7 @@
|
||||
| ```kv_cache_ratio``` | `float` | KVCache块按kv_cache_ratio比例分给Prefill阶段和Decode阶段, 默认0.75 |
|
||||
| ```enable_prefix_caching``` | `bool` | 是否开启Prefix Caching,默认False |
|
||||
| ```swap_space``` | `float` | 开启Prefix Caching时,用于swap KVCache的CPU内存大小,单位GB,默认None |
|
||||
| ```enable_chunk_prefill``` | `bool` | 开启Chunked Prefill,默认False |
|
||||
| ```enable_chunked_prefill``` | `bool` | 开启Chunked Prefill,默认False |
|
||||
| ```max_num_partial_prefills``` | `int` | 开启Chunked Prefill时,Prefill阶段的最大并发数,默认1 |
|
||||
| ```max_long_partial_prefills``` | `int` | 开启Chunked Prefill时,Prefill阶段并发中包启的最多长请求数,默认1 |
|
||||
| ```long_prefill_token_threshold``` | `int` | 开启Chunked Prefill时,请求Token数超过此值的请求被视为长请求,默认为max_model_len*0.04 |
|
||||
@@ -113,5 +113,5 @@ FastDeploy 的初始化顺序为先使用 `gpu_memory_utilization` 参数计算
|
||||
...
|
||||
```
|
||||
- 当开启 ```use_cudagraph``` 时,暂时只支持单卡推理,即 ```tensor_parallel_size``` 设为1。
|
||||
- 当开启 ```use_cudagraph``` 时,暂不支持开启 ```enable_prefix_caching``` 或 ```enable_chunk_prefill``` 。
|
||||
- 当开启 ```use_cudagraph``` 时,暂不支持开启 ```enable_prefix_caching``` 或 ```enable_chunked_prefill``` 。
|
||||
- 当开启 ```use_cudagraph``` 后,size小于等于 ```max_capture_batch_size``` 的batch会由CudaGraph来执行前向计算,大于 ```max_capture_batch_size``` 的batch会由原本的动态图/静态图执行前向计算。如果希望所有batch size均由CudaGraph来执行,```max_capture_batch_size``` 的值建议与 ```max_num_seqs``` 一致。```max_capture_batch_size``` 大于 ```max_num_seqs``` 会导致浪费,会多捕获一些推理时不会遇到的batch,占用更多时间与显存。
|
||||
|
@@ -1,5 +1,6 @@
|
||||
# FastDeploy 环境变量说明
|
||||
FastDeploy 的环境变量保存在了代码库根目录下 fastdeploy/envs.py 文件中,以下是其对应的中文版说明:
|
||||
|
||||
```python
|
||||
environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# 构建 FastDeploy 时使用的 CUDA 架构版本,这是一个字符串列表,例如[80,90]
|
||||
@@ -50,7 +51,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_ATTENTION_BACKEND":
|
||||
lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
|
||||
|
||||
# 设置采样类别,当前可设置为 "base"、"air" 或 "rejection"
|
||||
# 设置采样类别,当前可设置为 "base"、"base_non_truncated"、"air" 或 "rejection"
|
||||
"FD_SAMPLING_CLASS":
|
||||
lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
|
||||
|
||||
@@ -65,6 +66,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# 是否从单机 PD 分离转换为集中式推理
|
||||
"FD_PD_CHANGEABLE":
|
||||
lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
|
||||
|
||||
|
||||
}
|
||||
```
|
||||
```
|
||||
|
@@ -15,11 +15,14 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# suppress warning log from paddlepaddle
|
||||
os.environ["GLOG_minloglevel"] = "2"
|
||||
# suppress log from aistudio
|
||||
os.environ["AISTUDIO_LOG"] = "critical"
|
||||
from fastdeploy.utils import version
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.entrypoints.llm import LLM
|
||||
|
||||
@@ -30,3 +33,48 @@ try:
|
||||
use_triton_in_paddle.make_triton_compatible_with_paddle()
|
||||
except ImportError:
|
||||
pass
|
||||
# TODO(tangbinhan): remove this code
|
||||
|
||||
|
||||
def _patch_fastsafetensors():
|
||||
try:
|
||||
file_path = subprocess.check_output([
|
||||
sys.executable, "-c", "import fastsafetensors, os; \
|
||||
print(os.path.join(os.path.dirname(fastsafetensors.__file__), \
|
||||
'frameworks', '_paddle.py'))"
|
||||
]).decode().strip()
|
||||
|
||||
with open(file_path, 'r') as f:
|
||||
content = f.read()
|
||||
if "DType.U16: DType.BF16," in content and "DType.U8: paddle.uint8," in content:
|
||||
return
|
||||
|
||||
modified = False
|
||||
if "DType.U16: DType.BF16," not in content:
|
||||
lines = content.splitlines()
|
||||
new_lines = []
|
||||
inside_block = False
|
||||
for line in lines:
|
||||
new_lines.append(line)
|
||||
if 'need_workaround_dtypes: Dict[DType, DType] = {' in line:
|
||||
inside_block = True
|
||||
elif inside_block and '}' in line:
|
||||
new_lines.insert(-1, ' DType.U16: DType.BF16,')
|
||||
inside_block = False
|
||||
modified = True
|
||||
content = "\n".join(new_lines)
|
||||
|
||||
if "DType.I8: paddle.uint8," in content:
|
||||
content = content.replace("DType.I8: paddle.uint8,",
|
||||
"DType.U8: paddle.uint8,")
|
||||
modified = True
|
||||
|
||||
if modified:
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(content + "\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to patch fastsafetensors: {e}")
|
||||
|
||||
|
||||
_patch_fastsafetensors()
|
||||
|
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
@@ -51,7 +51,6 @@ class ModelConfig(PretrainedConfig):
|
||||
top_p = 0.0
|
||||
temperature = 1.0
|
||||
rope_theta = 10000.0
|
||||
rope_scaling = None
|
||||
penalty_score = 1.0
|
||||
frequency_score = 0.0
|
||||
presence_score = 0.0
|
||||
@@ -70,7 +69,6 @@ class ModelConfig(PretrainedConfig):
|
||||
max_seq_len: int = 512,
|
||||
initializer_range: float = 0.02,
|
||||
use_rope=True,
|
||||
use_fast_ffn: bool = False,
|
||||
rope_theta: int = 10000,
|
||||
rope_3d: bool = False,
|
||||
ori_vocab_size: int | None = None,
|
||||
@@ -86,6 +84,7 @@ class ModelConfig(PretrainedConfig):
|
||||
head_dim: Optional[int] = None,
|
||||
tie_word_embeddings: bool = False,
|
||||
is_quantized: bool = False,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -105,7 +104,6 @@ class ModelConfig(PretrainedConfig):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.use_rope = use_rope
|
||||
self.use_fast_ffn = use_fast_ffn
|
||||
self.rope_theta = rope_theta
|
||||
self.ori_vocab_size = ori_vocab_size or vocab_size
|
||||
self.max_seq_len = max_seq_len
|
||||
@@ -126,6 +124,7 @@ class ModelConfig(PretrainedConfig):
|
||||
self.dtype = dtype
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.is_quantized = is_quantized
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -142,6 +141,7 @@ class MoEConfig:
|
||||
moe_num_shared_experts = (0, )
|
||||
moe_layer_start_index = 0
|
||||
moe_layer_end_index = None
|
||||
moe_use_aux_free: bool = False
|
||||
num_max_dispatch_tokens_per_rank = 256
|
||||
im_patch_id = (
|
||||
100295 # multimodality, TODO(liuyuanle): read from config.json
|
||||
@@ -163,7 +163,6 @@ class ParallelConfig:
|
||||
# The embedding weight distributed on your gpu cards is divided by row or column.
|
||||
# Defaults to False means divide by row. When vocab_size can not be divided by world_size
|
||||
# but hidden_size can, we can consider split embedding weight by column.
|
||||
column_cut = False # (bool, optional)
|
||||
"""
|
||||
From old version worker args
|
||||
TODO(gongshaotian): Reclassify
|
||||
@@ -194,18 +193,13 @@ class ParallelConfig:
|
||||
engine_pid: Optional[int] = None
|
||||
# Do profile or not
|
||||
do_profile: bool = False
|
||||
# Dynamic load weight or not
|
||||
dynamic_load_weight: bool = False
|
||||
#
|
||||
pad_token_id: int = -1
|
||||
#
|
||||
eos_tokens_lens: int = 2
|
||||
# Enable chunked prefill
|
||||
enable_chunked_prefill: str = "store_true"
|
||||
"""
|
||||
- APPEND_ATTN:
|
||||
"""
|
||||
attention_backend: str = "APPEND_ATTN"
|
||||
|
||||
max_num_batched_tokens: int = 2048
|
||||
# enable prefix cache
|
||||
enable_prefix_caching = None
|
||||
@@ -354,9 +348,27 @@ class GraphOptimizationConfig:
|
||||
@dataclass
|
||||
class LoadConfig:
|
||||
"""
|
||||
Configuration for loading parameter
|
||||
Configuration for dynamic weight loading strategies
|
||||
|
||||
Attributes:
|
||||
dynamic_load_weight: Whether to enable dynamic weight loading
|
||||
load_strategy: Specifies the weight loading method when enabled:
|
||||
- 'ipc': Real-time IPC streaming with automatic resharding
|
||||
- 'ipc_no_reshard': Real-time IPC streaming without weight resharding
|
||||
- 'ipc_snapshot': Load from disk snapshot of IPC weights
|
||||
- 'meta': provided for RL training workers; weights are not loaded
|
||||
- None: No dynamic loading
|
||||
"""
|
||||
pass
|
||||
use_fastsafetensor: bool = False
|
||||
dynamic_load_weight: bool = False
|
||||
load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.load_strategy is not None and not self.dynamic_load_weight:
|
||||
raise ValueError("Load strategy requires dynamic_load_weight=True")
|
||||
|
||||
if self.dynamic_load_weight and self.load_strategy is None:
|
||||
raise ValueError("Must specify load_strategy when dynamic_load_weight is True")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -392,7 +404,7 @@ class FDConfig:
|
||||
init=True) # type: ignore
|
||||
device_config: DeviceConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
load_config: LoadConfig = field(default=None, init=True) # type: ignore
|
||||
load_config: LoadConfig = field(default=None, init=True)
|
||||
quant_config: Optional[QuantConfigBase] = None
|
||||
graph_opt_config: Optional[GraphOptimizationConfig] = None
|
||||
moe_config: MoEConfig = field(default=None, init=True) # type: ignore
|
||||
|
@@ -16,48 +16,54 @@
|
||||
|
||||
import time
|
||||
import os
|
||||
import subprocess
|
||||
import signal
|
||||
import multiprocessing
|
||||
|
||||
from fastdeploy.entrypoints.llm import LLM
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
|
||||
|
||||
model_name_or_path = "./models/eb45t02/"
|
||||
|
||||
model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle"
|
||||
|
||||
|
||||
|
||||
prefill_cmd = (f"FD_LOG_DIR=log_prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python fastdeploy.entrypoints.openai.api_server.py"
|
||||
+ f" --model {model_name_or_path} --port 9811"
|
||||
+ f" --splitwise-role prefill --tensor-parallel-size 4"
|
||||
+ f" --engine-worker-queue-port 6676 --cache-queue-port 55663")
|
||||
def start_decode(model_name_or_path):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
os.environ["FD_LOG_DIR"] = "log_decode"
|
||||
llm_decode = LLM(
|
||||
model=model_name_or_path,
|
||||
tensor_parallel_size=1,
|
||||
splitwise_role="decode",
|
||||
engine_worker_queue_port=6678,
|
||||
innode_prefill_ports=[6676],
|
||||
cache_queue_port=55668
|
||||
)
|
||||
return llm_decode
|
||||
|
||||
prefill_instance = subprocess.Popen(
|
||||
prefill_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
shell=True,
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
def start_prefill(model_name_or_path):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
os.environ["FD_LOG_DIR"] = "log_prefill"
|
||||
llm_prefill = LLM(
|
||||
model=model_name_or_path,
|
||||
tensor_parallel_size=1,
|
||||
splitwise_role="prefill",
|
||||
engine_worker_queue_port=6677,
|
||||
cache_queue_port=55667,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
prefill = multiprocessing.Process(
target=start_prefill,
args=(model_name_or_path,))
prefill.start()
|
||||
time.sleep(10)
|
||||
llm_decode = start_decode(model_name_or_path)
|
||||
|
||||
output = llm_decode.generate(prompts=["who are you?", "what can you do?"], use_tqdm=True)
|
||||
print(output)
|
||||
|
||||
prefill.join()
|
||||
|
||||
|
||||
# # 超参设置
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
|
||||
os.environ["FD_LOG_DIR"] = "log_decode"
|
||||
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
|
||||
llm_decode = LLM(
|
||||
model=model_name_or_path,
|
||||
tensor_parallel_size=4,
|
||||
splitwise_role="decode",
|
||||
engine_worker_queue_port=6678,
|
||||
innode_prefill_ports=[6676],
|
||||
cache_queue_port=55668
|
||||
)
|
||||
|
||||
|
||||
output = llm_decode.generate(prompts=["who are you?", "what can you do?"], use_tqdm=True)
|
||||
print(output)
|
||||
|
||||
|
||||
os.killpg(prefill_instance.pid, signal.SIGTERM)
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@@ -17,13 +17,15 @@
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
|
||||
|
||||
@paddle.jit.marker.unified
|
||||
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
if paddle.in_dynamic_mode():
|
||||
hcg = dist.fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
dist.all_reduce(input_, group=mp_group)
|
||||
else:
|
||||
dist.all_reduce(input_)
|
||||
try:
|
||||
@paddle.jit.marker.unified
|
||||
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
if paddle.in_dynamic_mode():
|
||||
hcg = dist.fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
dist.all_reduce(input_, group=mp_group)
|
||||
else:
|
||||
dist.all_reduce(input_)
|
||||
except:
|
||||
tensor_model_parallel_all_reduce=None
|
@@ -87,10 +87,14 @@ class EngineArgs:
|
||||
"""
|
||||
Configuration for speculative execution.
|
||||
"""
|
||||
dynamic_load_weight: int = 0
|
||||
dynamic_load_weight: bool = False
|
||||
"""
|
||||
Whether to load model weights dynamically.
|
||||
"""
|
||||
load_strategy: str = "meta"
|
||||
"""
|
||||
Strategy used for dynamic weight loading.
|
||||
"""
|
||||
quantization: str = None
|
||||
guided_decoding_backend: str = "off"
|
||||
"""
|
||||
@@ -292,6 +296,12 @@ class EngineArgs:
|
||||
max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
|
||||
"""
|
||||
|
||||
enable_logprob: bool = False
|
||||
"""
|
||||
Flag to enable logprob output. Default is False (disabled).
|
||||
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to set default tokenizer if not provided.
|
||||
@@ -364,13 +374,16 @@ class EngineArgs:
|
||||
type=json.loads,
|
||||
default=EngineArgs.speculative_config,
|
||||
help="Configuration for speculative execution.")
|
||||
|
||||
model_group.add_argument(
|
||||
"--dynamic-load-weight",
|
||||
type=int,
|
||||
action='store_true',
|
||||
default=EngineArgs.dynamic_load_weight,
|
||||
help="Flag to indicate whether to load weight dynamically.")
|
||||
|
||||
model_group.add_argument(
|
||||
"--load-strategy",
|
||||
type=str,
|
||||
default=EngineArgs.load_strategy,
|
||||
help="Flag to dynamic load strategy.")
|
||||
model_group.add_argument("--engine-worker-queue-port",
|
||||
type=int,
|
||||
default=EngineArgs.engine_worker_queue_port,
|
||||
@@ -383,6 +396,7 @@ class EngineArgs:
|
||||
"default is None. The priority of this configuration "\
|
||||
"is lower than that of the config file. " \
|
||||
"More complex quantization methods need to be configured via the config file.")
|
||||
|
||||
model_group.add_argument(
|
||||
"--enable-static-graph-inference",
|
||||
action='store_true',
|
||||
@@ -408,6 +422,11 @@ class EngineArgs:
|
||||
help=
|
||||
"Disabled any whitespaces when using guided decoding backend XGrammar."
|
||||
)
|
||||
model_group.add_argument("--enable-logprob",
|
||||
action="store_true",
|
||||
default=EngineArgs.enable_logprob,
|
||||
help="Enable output of token-level log probabilities."
|
||||
)
|
||||
|
||||
# Parallel processing parameters group
|
||||
parallel_group = parser.add_argument_group("Parallel Configuration")
|
||||
@@ -668,8 +687,9 @@ class EngineArgs:
|
||||
"""
|
||||
return ModelConfig(model_name_or_path=self.model,
|
||||
config_json_file=self.model_config_name,
|
||||
quantization=self.quantization,
|
||||
dynamic_load_weight=self.dynamic_load_weight,
|
||||
quantization=self.quantization)
|
||||
load_strategy=self.load_strategy)
|
||||
|
||||
def create_cache_config(self, model_cfg) -> CacheConfig:
|
||||
"""
|
||||
@@ -749,6 +769,9 @@ class EngineArgs:
|
||||
|
||||
speculative_cfg = self.create_speculative_config()
|
||||
|
||||
assert not (self.use_cudagraph and self.enable_prefix_caching), \
|
||||
"Prefix caching cannot be used with CUDA graph"
|
||||
|
||||
return Config(
|
||||
model_name_or_path=self.model,
|
||||
model_config=model_cfg,
|
||||
@@ -779,4 +802,5 @@ class EngineArgs:
|
||||
max_capture_batch_size=self.max_capture_batch_size,
|
||||
guided_decoding_backend=self.guided_decoding_backend,
|
||||
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
|
||||
enable_logprob = self.enable_logprob,
|
||||
)
|
||||
|
@@ -16,6 +16,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
@@ -41,7 +42,8 @@ class ModelConfig:
|
||||
def __init__(self,
|
||||
model_name_or_path: str,
|
||||
config_json_file: str = "config.json",
|
||||
dynamic_load_weight: int = 0,
|
||||
dynamic_load_weight: bool = False,
|
||||
load_strategy: str="meta",
|
||||
quantization: str = None,
|
||||
download_dir: Optional[str] = None):
|
||||
"""
|
||||
@@ -55,6 +57,7 @@ class ModelConfig:
|
||||
self.model_dir = model_name_or_path
|
||||
self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
|
||||
self.dynamic_load_weight = dynamic_load_weight
|
||||
self.load_strategy = load_strategy
|
||||
self.quantization = quantization
|
||||
|
||||
config_file = os.path.join(model_name_or_path, config_json_file)
|
||||
@@ -465,7 +468,63 @@ class ParallelConfig:
|
||||
llm_logger.info("Parallel Configuration Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
llm_logger.info("==================")
|
||||
llm_logger.info(
|
||||
"=============================================================")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitConfig:
|
||||
"""
|
||||
Configuration for tracking version information from version.txt
|
||||
|
||||
Attributes:
|
||||
fastdeploy_commit: Full FastDeploy git commit hash
|
||||
paddle_version: PaddlePaddle version string
|
||||
paddle_commit: PaddlePaddle git commit hash
|
||||
cuda_version: CUDA version string
|
||||
compiler_version: CXX compiler version string
|
||||
"""
|
||||
fastdeploy_commit: str = ""
|
||||
paddle_version: str = ""
|
||||
paddle_commit: str = ""
|
||||
cuda_version: str = ""
|
||||
compiler_version: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Automatically load version info when initialized"""
|
||||
self._load_from_version_file()
|
||||
|
||||
def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
|
||||
"""Internal method to load version info from file"""
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("fastdeploy GIT COMMIT ID:"):
|
||||
self.fastdeploy_commit = line.split(":")[1].strip()
|
||||
elif line.startswith("Paddle version:"):
|
||||
self.paddle_version = line.split(":")[1].strip()
|
||||
elif line.startswith("Paddle GIT COMMIT ID:"):
|
||||
self.paddle_commit = line.split(":")[1].strip()
|
||||
elif line.startswith("CUDA version:"):
|
||||
self.cuda_version = line.split(":")[1].strip()
|
||||
elif line.startswith("CXX compiler version:"):
|
||||
self.compiler_version = line.split(":")[1].strip()
|
||||
except FileNotFoundError:
|
||||
llm_logger.info(f"Warning: Version file not found at {file_path}")
|
||||
except Exception as e:
|
||||
llm_logger.info(f"Warning: Could not read version file - {str(e)}")
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
print all config
|
||||
|
||||
"""
|
||||
llm_logger.info("Fasedeploy Commit Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
llm_logger.info(
|
||||
"=============================================================")
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -500,6 +559,7 @@ class Config:
|
||||
cache_config: CacheConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
commit_config: CommitConfig = CommitConfig(),
|
||||
model_name_or_path: str = None,
|
||||
tokenizer: str = None,
|
||||
tensor_parallel_size: int = 8,
|
||||
@@ -525,6 +585,7 @@ class Config:
|
||||
max_capture_batch_size: int = 64,
|
||||
guided_decoding_backend: Optional[str] = None,
|
||||
disable_any_whitespace: bool = False,
|
||||
enable_logprob: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the Config class.
|
||||
@@ -559,6 +620,7 @@ class Config:
|
||||
self.cache_config = cache_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.parallel_config = parallel_config
|
||||
self.commit_config = commit_config
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.tokenizer = tokenizer
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
@@ -584,12 +646,10 @@ class Config:
|
||||
self.guided_decoding_backend = guided_decoding_backend
|
||||
self.disable_any_whitespace = disable_any_whitespace
|
||||
|
||||
|
||||
if self.innode_prefill_ports is not None:
|
||||
if not isinstance(self.innode_prefill_ports, list):
|
||||
ports = str(self.innode_prefill_ports).split(',')
|
||||
self.innode_prefill_ports = [int(port) for port in ports]
|
||||
|
||||
|
||||
assert self.splitwise_role in ["mixed", "prefill", "decode"]
|
||||
|
||||
@@ -619,6 +679,8 @@ class Config:
|
||||
self.parallel_config.expert_parallel_size), 8))])
|
||||
self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
|
||||
|
||||
self.enable_logprob = enable_logprob
|
||||
|
||||
self.read_from_config()
|
||||
self.postprocess()
|
||||
self.check()
|
||||
@@ -728,7 +790,7 @@ class Config:
|
||||
), "XPU currently do not support guided_decoding"
|
||||
|
||||
try:
|
||||
pass
|
||||
import xgrammar # noqa
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
|
||||
@@ -749,7 +811,11 @@ class Config:
|
||||
if k == "generation_config" and v is not None:
|
||||
for gck, gcv in v.to_dict().items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
|
||||
elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
|
||||
elif (k == "cache_config" or
|
||||
k == "model_config" or
|
||||
k == "scheduler_config" or
|
||||
k == "parallel_config" or
|
||||
k == "commit_config"):
|
||||
v.print()
|
||||
else:
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
|
@@ -47,7 +47,8 @@ from fastdeploy.output.token_processor import (TokenProcessor,
|
||||
WarmUpTokenProcessor)
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import EngineError, console_logger, llm_logger
|
||||
|
||||
from fastdeploy.metrics.trace_util import extract_from_metadata, start_span, start_span_request
|
||||
from opentelemetry import trace
|
||||
|
||||
class LLMEngine(object):
|
||||
"""
|
||||
@@ -165,12 +166,6 @@ class LLMEngine(object):
|
||||
disable_any_whitespace=self.cfg.disable_any_whitespace,
|
||||
)
|
||||
|
||||
def reset_scheduler(self):
|
||||
"""
|
||||
Reset the scheduler to its initial state.
|
||||
"""
|
||||
self.scheduler.reset()
|
||||
|
||||
def start(self, api_server_pid=None):
|
||||
"""
|
||||
Initializes the engine and starts its sub-services.
|
||||
@@ -286,6 +281,8 @@ class LLMEngine(object):
|
||||
while self.running:
|
||||
try:
|
||||
results = self.scheduler.get_results()
|
||||
if len(results) == 0:
|
||||
time.sleep(0.001)
|
||||
for request_id, contents in results.items():
|
||||
for result in contents:
|
||||
self.zmq_server.send_multipart(request_id, result)
|
||||
@@ -379,7 +376,10 @@ class LLMEngine(object):
|
||||
request, insert_task = None, []
|
||||
results: List[Tuple[str, Optional[str]]] = list()
|
||||
if data:
|
||||
request = Request.from_dict(data)
|
||||
request = Request.from_dict(data)
|
||||
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
|
||||
|
||||
|
||||
llm_logger.debug(f"Receive request: {request}")
|
||||
|
||||
err_msg = None
|
||||
@@ -444,8 +444,8 @@ class LLMEngine(object):
|
||||
enable_thinking = None
|
||||
if kwargs is not None:
|
||||
enable_thinking = kwargs.get("enable_thinking", None)
|
||||
request = self.data_processor.process_request(request,
|
||||
self.cfg.max_model_len, enable_thinking=enable_thinking)
|
||||
request = self.data_processor.process_request(
|
||||
request, self.cfg.max_model_len, enable_thinking=enable_thinking)
|
||||
request.prompt_token_ids_len = len(request.prompt_token_ids)
|
||||
input_ids_len = request.prompt_token_ids_len
|
||||
request.set(
|
||||
@@ -453,7 +453,8 @@ class LLMEngine(object):
|
||||
min(self.cfg.max_model_len - input_ids_len,
|
||||
request.get("max_tokens")))
|
||||
if request.get("reasoning_max_tokens") is None:
|
||||
default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1)
|
||||
default_reasoning_max_tokens = max(
|
||||
int(request.get("max_tokens") * 0.8), 1)
|
||||
request.set("reasoning_max_tokens", default_reasoning_max_tokens)
|
||||
min_tokens = request.get("min_tokens")
|
||||
if input_ids_len + min_tokens >= self.cfg.max_model_len:
|
||||
@@ -709,6 +710,8 @@ class LLMEngine(object):
|
||||
"""
|
||||
Insert tasks to engine.
|
||||
"""
|
||||
for task in tasks:
|
||||
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
|
||||
# TODO 返回至 scheduler
|
||||
if allocated:
|
||||
current_tasks = []
|
||||
@@ -963,8 +966,8 @@ class LLMEngine(object):
|
||||
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
|
||||
"FLAGS_use_append_attn": 1,
|
||||
"NCCL_ALGO": "Ring",
|
||||
"FLAGS_hardamard_moe_block_size": 128,
|
||||
"FLAGS_max_partition_size": 32768,
|
||||
"FLAGS_hardamard_moe_block_size": 128,
|
||||
}
|
||||
# environment variables needed by Dy2St
|
||||
variables.update({
|
||||
@@ -1017,6 +1020,12 @@ class LLMEngine(object):
|
||||
worker_path = "../worker/vl_worker_process.py"
|
||||
py_script = os.path.join(current_dir_path, worker_path)
|
||||
|
||||
ori_vocab_size = (
|
||||
len(self.data_processor.tokenizer.sp_model)
|
||||
if hasattr(self.data_processor.tokenizer, 'sp_model')
|
||||
else len(self.data_processor.tokenizer.vocab)
|
||||
)
|
||||
|
||||
arguments = (
|
||||
f" --nnodes {str(self.cfg.nnode)}"
|
||||
f" --devices {self.cfg.device_ids} {py_script}"
|
||||
@@ -1037,13 +1046,14 @@ class LLMEngine(object):
|
||||
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
|
||||
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
|
||||
f" --quantization {self.cfg.model_config.quantization}"
|
||||
f" --ori_vocab_size {len(self.data_processor.tokenizer)}"
|
||||
f" --ori_vocab_size {ori_vocab_size}"
|
||||
f" --speculative_method {self.cfg.speculative_config.method}"
|
||||
f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}"
|
||||
f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
|
||||
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
|
||||
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
|
||||
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}")
|
||||
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
|
||||
f" --load_strategy {self.cfg.model_config.load_strategy}")
|
||||
|
||||
worker_append_flag = {
|
||||
"enable_expert_parallel":
|
||||
@@ -1058,6 +1068,7 @@ class LLMEngine(object):
|
||||
self.cfg.enable_static_graph_inference,
|
||||
"use_cudagraph": self.cfg.use_cudagraph,
|
||||
"disable_any_whitespace": self.cfg.disable_any_whitespace,
|
||||
"enable_logprob": self.cfg.enable_logprob,
|
||||
}
|
||||
for worker_flag, value in worker_append_flag.items():
|
||||
if value:
|
||||
@@ -1188,8 +1199,9 @@ class LLMEngine(object):
|
||||
line = line.decode('utf-8', errors='ignore')
|
||||
if self.worker_init_status.get("finished", False):
|
||||
break
|
||||
if match := re.search(r'Loading checkpoint shards:\s*(\d+)',
|
||||
line):
|
||||
if match := re.search(
|
||||
r'Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)',
|
||||
line):
|
||||
self.worker_init_status["weight_loadding"] = eval(
|
||||
match.group(1)) * 1.0 / 100
|
||||
elif (match := re.search(r'Start load layer (\d+)',
|
||||
|
@@ -24,6 +24,7 @@ import numpy
|
||||
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
from fastdeploy.worker.output import LogprobsLists
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -54,7 +55,8 @@ class Request:
|
||||
guided_grammar: Optional[Any] = None,
|
||||
structural_tag: Optional[Any] = None,
|
||||
guided_json_object: Optional[bool] = None,
|
||||
enable_thinking: Optional[bool] = True) -> None:
|
||||
enable_thinking: Optional[bool] = True,
|
||||
trace_carrier: dict = dict()) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
@@ -90,6 +92,7 @@ class Request:
|
||||
self.multimodal_data = multimodal_data
|
||||
|
||||
self.enable_thinking = enable_thinking
|
||||
self.trace_carrier = trace_carrier
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict):
|
||||
@@ -119,7 +122,8 @@ class Request:
|
||||
guided_grammar=d.get("guided_grammar", None),
|
||||
structural_tag=d.get("structural_tag", None),
|
||||
guided_json_object=d.get("guided_json_object", None),
|
||||
enable_thinking=d.get("enable_thinking", True))
|
||||
enable_thinking=d.get("enable_thinking", True),
|
||||
trace_carrier=d.get("trace_carrier", {}))
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""convert Request into a serializable dict """
|
||||
@@ -141,7 +145,8 @@ class Request:
|
||||
"raw_request": self.raw_request,
|
||||
"disaggregate_info": self.disaggregate_info,
|
||||
"draft_token_ids": self.draft_token_ids,
|
||||
"enable_thinking": self.enable_thinking
|
||||
"enable_thinking": self.enable_thinking,
|
||||
"trace_carrier": self.trace_carrier
|
||||
}
|
||||
add_params = [
|
||||
"guided_json", "guided_regex", "guided_choice", "guided_grammar",
|
||||
@@ -189,6 +194,8 @@ class CompletionOutput:
|
||||
index: int
|
||||
send_idx: int
|
||||
token_ids: list[int]
|
||||
logprob: Optional[float] = None
|
||||
top_logprobs: Optional[LogprobsLists] = None
|
||||
draft_token_ids: list[int] = None
|
||||
text: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
@@ -201,6 +208,8 @@ class CompletionOutput:
|
||||
"index": self.index,
|
||||
"send_idx": self.send_idx,
|
||||
"token_ids": self.token_ids,
|
||||
"logprob": self.logprob,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"draft_token_ids": self.draft_token_ids,
|
||||
"text": self.text,
|
||||
"reasoning_content": self.reasoning_content
|
||||
|
@@ -52,6 +52,7 @@ class SamplingParams:
|
||||
the model more random. Zero means greedy sampling.
|
||||
top_p: Float that controls the cumulative probability of the top tokens
|
||||
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
|
||||
top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
|
||||
seed: Random seed to use for the generation.
|
||||
stop: list of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
@@ -82,6 +83,7 @@ class SamplingParams:
|
||||
repetition_penalty: float = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
top_k: int = 0
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
@@ -111,6 +113,7 @@ class SamplingParams:
|
||||
repetition_penalty,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
seed=None,
|
||||
stop=None,
|
||||
stop_token_ids=None,
|
||||
@@ -129,7 +132,8 @@ class SamplingParams:
|
||||
repetition_penalty=repetition_penalty
|
||||
if repetition_penalty is not None else 1.0,
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
top_p=top_p if top_p is not None else 0.7,
|
||||
top_p=top_p,
|
||||
top_k=top_k if top_k is not None else 0,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
@@ -169,6 +173,13 @@ class SamplingParams:
|
||||
f"temperature must be non-negative, got {self.temperature}.")
|
||||
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
|
||||
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
|
||||
# quietly accept -1 as disabled, but prefer 0
|
||||
if self.top_k < -1:
|
||||
raise ValueError(f"top_k must be 0 (disable), or at least 1, "
|
||||
f"got {self.top_k}.")
|
||||
if not isinstance(self.top_k, int):
|
||||
raise TypeError(
|
||||
f"top_k must be an integer, got {type(self.top_k).__name__}")
|
||||
|
||||
if self.max_tokens is not None and self.max_tokens < 1:
|
||||
raise ValueError(
|
||||
@@ -188,6 +199,9 @@ class SamplingParams:
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
if self.logprobs is not None and self.logprobs > 20:
|
||||
raise ValueError(
|
||||
"Invalid value for 'top_logprobs': must be less than or equal to 20.")
|
||||
|
||||
if not 0 <= self.seed <= 922337203685477580:
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got "
|
||||
|
@@ -24,6 +24,7 @@ import zmq
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from prometheus_client import CONTENT_TYPE_LATEST
|
||||
from fastdeploy.metrics.trace_util import inject_to_metadata,instrument
|
||||
|
||||
from fastdeploy.engine.args_utils import EngineArgs
|
||||
from fastdeploy.engine.engine import LLMEngine
|
||||
@@ -32,7 +33,8 @@ from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
ErrorResponse)
|
||||
ErrorResponse,
|
||||
ControlSchedulerRequest)
|
||||
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from fastdeploy.entrypoints.openai.serving_completion import \
|
||||
OpenAIServingCompletion
|
||||
@@ -44,6 +46,7 @@ from fastdeploy.utils import (FlexibleArgumentParser, api_server_logger,
|
||||
console_logger, is_port_available,
|
||||
retrive_model_from_server)
|
||||
|
||||
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--port",
|
||||
default=8000,
|
||||
@@ -139,6 +142,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
instrument(app)
|
||||
|
||||
|
||||
# TODO 传递真实引擎值 通过pid 获取状态
|
||||
@@ -209,6 +213,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||
return JSONResponse(
|
||||
content={"error": "Worker Service Not Healthy"},
|
||||
status_code=304)
|
||||
inject_to_metadata(request)
|
||||
generator = await app.state.chat_handler.create_chat_completion(request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
@@ -273,10 +278,13 @@ def clear_load_weight(request: Request) -> Response:
|
||||
status_code=404)
|
||||
|
||||
|
||||
def launch_api_server(args) -> None:
|
||||
def launch_api_server() -> None:
|
||||
"""
|
||||
启动http服务
|
||||
"""
|
||||
if not is_port_available(args.host, args.port):
|
||||
raise Exception(f"The parameter `port`:{args.port} is already in use.")
|
||||
|
||||
api_server_logger.info(
|
||||
f"launch Fastdeploy api server... port: {args.port}")
|
||||
api_server_logger.info(f"args: {args.__dict__}")
|
||||
@@ -319,6 +327,11 @@ def run_metrics_server():
|
||||
|
||||
def launch_metrics_server():
|
||||
"""Metrics server running the sub thread"""
|
||||
if not is_port_available(args.host, args.metrics_port):
|
||||
raise Exception(
|
||||
f"The parameter `metrics_port`:{args.metrics_port} is already in use."
|
||||
)
|
||||
|
||||
prom_dir = cleanup_prometheus_files(True)
|
||||
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir
|
||||
metrics_server_thread = threading.Thread(target=run_metrics_server,
|
||||
@@ -339,10 +352,39 @@ def reset_scheduler():
|
||||
|
||||
if llm_engine is None:
|
||||
return Response("Engine not loaded", status_code=500)
|
||||
llm_engine.reset_scheduler()
|
||||
llm_engine.scheduler.reset()
|
||||
return Response("Scheduler Reset Successfully", status_code=200)
|
||||
|
||||
|
||||
@controller_app.post("/controller/scheduler")
|
||||
def control_scheduler(request: ControlSchedulerRequest):
|
||||
"""
|
||||
Control the scheduler behavior with the given parameters.
|
||||
"""
|
||||
content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)
|
||||
|
||||
global llm_engine
|
||||
if llm_engine is None:
|
||||
content.message = "Engine is not loaded"
|
||||
content.code = 500
|
||||
return JSONResponse(content=content.model_dump(), status_code=500)
|
||||
|
||||
if request.reset:
|
||||
llm_engine.scheduler.reset()
|
||||
|
||||
if request.load_shards_num or request.reallocate_shard:
|
||||
if hasattr(llm_engine.scheduler, "update_config") and callable(llm_engine.scheduler.update_config):
|
||||
llm_engine.scheduler.update_config(
|
||||
load_shards_num=request.load_shards_num,
|
||||
reallocate=request.reallocate_shard)
|
||||
else:
|
||||
content.message="This scheduler doesn't support the `update_config()` method."
|
||||
content.code=400
|
||||
return JSONResponse(content=content.model_dump(), status_code=400)
|
||||
|
||||
return JSONResponse(content=content.model_dump(), status_code=200)
|
||||
|
||||
|
||||
def run_controller_server():
|
||||
"""
|
||||
run controller server
|
||||
@@ -358,6 +400,11 @@ def launch_controller_server():
|
||||
if args.controller_port < 0:
|
||||
return
|
||||
|
||||
if not is_port_available(args.host, args.controller_port):
|
||||
raise Exception(
|
||||
f"The parameter `controller_port`:{args.controller_port} is already in use."
|
||||
)
|
||||
|
||||
controller_server_thread = threading.Thread(target=run_controller_server,
|
||||
daemon=True)
|
||||
controller_server_thread.start()
|
||||
@@ -366,19 +413,13 @@ def launch_controller_server():
|
||||
|
||||
def main():
|
||||
"""main函数"""
|
||||
if not is_port_available(args.host, args.port):
|
||||
raise Exception(f"The parameter `port`:{args.port} is already in use.")
|
||||
if not is_port_available(args.host, args.metrics_port):
|
||||
raise Exception(
|
||||
f"The parameter `metrics_port`:{args.metrics_port} is already in use."
|
||||
)
|
||||
|
||||
if load_engine() is None:
|
||||
return
|
||||
|
||||
launch_controller_server()
|
||||
launch_metrics_server()
|
||||
launch_api_server(args)
|
||||
launch_api_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -122,7 +122,8 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
"""
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
@@ -136,6 +137,21 @@ class ChatCompletionResponse(BaseModel):
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: UsageInfo
|
||||
|
||||
class LogProbEntry(BaseModel):
|
||||
"""
|
||||
Log probability entry.
|
||||
"""
|
||||
token: str
|
||||
logprob: float
|
||||
bytes: Optional[List[int]] = None
|
||||
top_logprobs: Optional[List["LogProbEntry"]] = None
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
"""
|
||||
LogProbs.
|
||||
"""
|
||||
content: Optional[List[LogProbEntry]] = None
|
||||
refusal: Optional[Union[str, None]] = None
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
"""
|
||||
@@ -154,6 +170,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
"""
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
|
||||
arrival_time: Optional[float] = None
|
||||
|
||||
@@ -292,6 +309,7 @@ class CompletionRequest(BaseModel):
|
||||
suffix: Optional[dict] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
response_format: Optional[AnyResponseFormat] = None
|
||||
@@ -391,6 +409,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
tools: Optional[List[ChatCompletionToolsParam]] = None
|
||||
model: Optional[str] = "default"
|
||||
frequency_penalty: Optional[float] = None
|
||||
logprobs: Optional[bool] = False
|
||||
top_logprobs: Optional[int] = 0
|
||||
# remove max_tokens when field is removed from OpenAI API
|
||||
max_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
@@ -405,6 +425,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
stream_options: Optional[StreamOptions] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@@ -432,6 +453,9 @@ class ChatCompletionRequest(BaseModel):
|
||||
if request_id is not None:
|
||||
req_dict['request_id'] = request_id
|
||||
|
||||
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
|
||||
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
|
||||
|
||||
if self.metadata is not None:
|
||||
for key, value in self.metadata.items():
|
||||
req_dict[key] = value
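A tiny hedged sketch of the request-dict mapping above, using plain values instead of the Pydantic model:

max_tokens, max_completion_tokens = 128, 256
logprobs, top_logprobs = True, 5
req_dict = {"max_tokens": max_completion_tokens or max_tokens,   # -> 256
            "logprobs": top_logprobs if logprobs else None}      # -> 5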
@@ -503,3 +527,27 @@ class ChatCompletionRequest(BaseModel):
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
|
||||
if (top_logprobs := data.get("top_logprobs")) is not None:
|
||||
if top_logprobs < 0:
|
||||
raise ValueError("`top_logprobs` must be a positive value.")
|
||||
|
||||
if top_logprobs > 0 and not data.get("logprobs"):
|
||||
raise ValueError(
|
||||
"when using `top_logprobs`, `logprobs` must be set to true."
|
||||
)
|
||||
|
||||
return data
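For illustration, the same validation rules as a standalone helper (a sketch, not the class method itself):

def _check_logprobs(data: dict) -> dict:
    if (top_logprobs := data.get("top_logprobs")) is not None:
        if top_logprobs < 0:
            raise ValueError("`top_logprobs` must be a positive value.")
        if top_logprobs > 0 and not data.get("logprobs"):
            raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
    return data

_check_logprobs({"logprobs": True, "top_logprobs": 5})   # passes
# _check_logprobs({"top_logprobs": 5})                   # raises ValueError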
class ControlSchedulerRequest(BaseModel):
|
||||
"""
|
||||
Control scheduler request to the engine.
|
||||
"""
|
||||
reset: Optional[bool] = False
|
||||
load_shards_num: Optional[int] = None
|
||||
reallocate_shard: Optional[bool] = False
|
@@ -15,34 +15,23 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiozmq
|
||||
from aiozmq import zmq
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Callable, Optional, Union, List
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
import aiozmq
|
||||
from aiozmq import zmq
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel
|
||||
from fastdeploy.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatMessage,
|
||||
UsageInfo,
|
||||
PromptTokenUsageInfo,
|
||||
ChatCompletionResponse,
|
||||
ErrorResponse,
|
||||
)
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
LogProbEntry, LogProbs, PromptTokenUsageInfo, UsageInfo)
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
|
||||
from fastdeploy.utils import api_server_logger
|
||||
|
||||
from fastdeploy.engine.request import RequestOutput
|
||||
|
||||
from fastdeploy.worker.output import LogprobsLists
|
||||
|
||||
|
||||
class OpenAIServingChat:
|
||||
@@ -115,6 +104,7 @@ class OpenAIServingChat:
|
||||
num_choices = 1
|
||||
max_streaming_response_tokens = 1
|
||||
enable_thinking = None
|
||||
include_stop_str_in_output = False
|
||||
if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
|
||||
max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]
|
||||
|
||||
@@ -157,14 +147,15 @@ class OpenAIServingChat:
|
||||
current_waiting_time = 0
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
|
||||
res = json.loads(raw_data[-1].decode('utf-8'))
|
||||
if res.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(res["error_msg"]))
|
||||
if request.metadata is not None:
|
||||
enable_thinking = request.metadata.get("enable_thinking")
|
||||
include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
res, stream=True, enable_thinking=enable_thinking)
|
||||
res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
|
||||
|
||||
if res['metrics']['first_token_time'] is not None:
|
||||
arrival_time = res['metrics']['first_token_time']
|
||||
@@ -200,6 +191,19 @@ class OpenAIServingChat:
|
||||
|
||||
output = res["outputs"]
|
||||
delta_text = output["text"]
|
||||
raw_top_logprobs = output["top_logprobs"]
|
||||
logprobs_res = None
|
||||
if raw_top_logprobs is not None:
|
||||
top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=raw_top_logprobs[0],
|
||||
logprobs=raw_top_logprobs[1],
|
||||
sampled_token_ranks=raw_top_logprobs[2],
|
||||
)
|
||||
logprobs_res = self.build_logprobs_response(
|
||||
request_logprobs=request.logprobs,
|
||||
response_logprobs=top_logprobs,
|
||||
request_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
|
||||
previous_num_tokens += len(output["token_ids"])
|
||||
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
|
||||
@@ -208,12 +212,15 @@ class OpenAIServingChat:
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs_res,
|
||||
arrival_time=arrival_time
|
||||
)
|
||||
if res["finished"]:
|
||||
num_choices -= 1
|
||||
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
|
||||
if request.max_tokens is None or previous_num_tokens != request.max_tokens:
|
||||
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and \
|
||||
output.get("finish_reason", "") == "tool_calls":
|
||||
@@ -221,6 +228,9 @@ class OpenAIServingChat:
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
||||
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
|
||||
if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
|
||||
choice.delta.token_ids = output["token_ids"]
|
||||
if include_continuous_usage:
|
||||
@@ -274,6 +284,7 @@ class OpenAIServingChat:
|
||||
created_time = int(time.time())
|
||||
final_res = None
|
||||
enable_thinking = None
|
||||
include_stop_str_in_output = False
|
||||
try:
|
||||
dealer = await aiozmq.create_zmq_stream(
|
||||
zmq.DEALER,
|
||||
@@ -283,6 +294,7 @@ class OpenAIServingChat:
|
||||
final_res = None
|
||||
previous_num_tokens = 0
|
||||
current_waiting_time = 0
|
||||
logprob_contents = []
|
||||
while True:
|
||||
try:
|
||||
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
|
||||
@@ -303,10 +315,27 @@ class OpenAIServingChat:
|
||||
raise ValueError("{}".format(data["error_msg"]))
|
||||
if request.metadata is not None:
|
||||
enable_thinking = request.metadata.get("enable_thinking")
|
||||
include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
|
||||
data = self.engine_client.data_processor.process_response_dict(
|
||||
data, stream=False, enable_thinking=enable_thinking)
|
||||
data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
|
||||
# api_server_logger.debug(f"Client {request_id} received: {data}")
|
||||
previous_num_tokens += len(data["outputs"]["token_ids"])
|
||||
# Handle logprobs in the response
|
||||
output = data["outputs"]
|
||||
raw_top_logprobs = output["top_logprobs"]
|
||||
if raw_top_logprobs is not None:
|
||||
top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=raw_top_logprobs[0],
|
||||
logprobs=raw_top_logprobs[1],
|
||||
sampled_token_ranks=raw_top_logprobs[2],
|
||||
)
|
||||
logprobs_res = self.build_logprobs_response(
|
||||
request_logprobs=request.logprobs,
|
||||
response_logprobs=top_logprobs,
|
||||
request_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
if logprobs_res and logprobs_res.content is not None:
|
||||
logprob_contents.extend(logprobs_res.content)
|
||||
if data["finished"]:
|
||||
final_res = data
|
||||
break
|
||||
@@ -322,19 +351,30 @@ class OpenAIServingChat:
|
||||
tool_calls=output.get("tool_call_content"),
|
||||
token_ids=output.get("token_ids")
|
||||
)
|
||||
logprobs_full_res = None
|
||||
if logprob_contents:
|
||||
logprobs_full_res = LogProbs(
|
||||
content=logprob_contents
|
||||
)
|
||||
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=message,
|
||||
logprobs=logprobs_full_res,
|
||||
finish_reason=None
|
||||
)
|
||||
if request.max_tokens is None or previous_num_tokens != request.max_tokens:
|
||||
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and \
|
||||
output.get("finish_reason", "") == "tool_calls":
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
||||
if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
choices.append(choice)
|
||||
|
||||
num_prompt_tokens = len(prompt_token_ids)
|
||||
@@ -353,3 +393,55 @@ class OpenAIServingChat:
|
||||
choices=choices,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
def build_logprobs_response(
|
||||
self,
|
||||
request_logprobs: bool,
|
||||
response_logprobs: Optional[LogprobsLists],
|
||||
request_top_logprobs: int,
|
||||
) -> Optional[LogProbs]:
|
||||
"""
|
||||
Construct a logprobs response object in line with the OpenAI style.
|
||||
Retain the complete top-k candidates and avoid circular references.
|
||||
"""
|
||||
|
||||
# Parameter validation
|
||||
if (
|
||||
response_logprobs is None
|
||||
or not request_logprobs
|
||||
or request_top_logprobs is None
|
||||
or request_top_logprobs < 0
|
||||
):
|
||||
return None
|
||||
|
||||
try:
|
||||
# The top-k candidates for the current token
|
||||
topk_token_ids = response_logprobs.logprob_token_ids[0][:request_top_logprobs + 1]
|
||||
topk_logprobs = response_logprobs.logprobs[0][:request_top_logprobs + 1]
|
||||
|
||||
# Construct the candidate token structure (LogProbEntry) of topk
|
||||
top_logprob_entries: List[LogProbEntry] = []
|
||||
for tid, lp in zip(topk_token_ids, topk_logprobs):
|
||||
token_str = self.engine_client.data_processor.process_logprob_response([tid],
|
||||
clean_up_tokenization_spaces=False)
|
||||
# token_bytes = token_str.encode("utf-8", errors="replace")
|
||||
entry = LogProbEntry(
|
||||
token=token_str,
|
||||
logprob=lp,
|
||||
# bytes=list(token_bytes)
|
||||
)
|
||||
top_logprob_entries.append(entry)
|
||||
# Construct the sampled token object (avoid sharing references with top_logprob_entries)
|
||||
sampled_entry = LogProbEntry(
|
||||
token=top_logprob_entries[0].token,
|
||||
logprob=top_logprob_entries[0].logprob,
|
||||
bytes=top_logprob_entries[0].bytes,
|
||||
top_logprobs=top_logprob_entries[1:] # Here are the complete topk candidates
|
||||
)
|
||||
|
||||
return LogProbs(content=[sampled_entry])
|
||||
|
||||
except Exception as e:
|
||||
api_server_logger.error("Error in build_logprobs_response: %s", e)
|
||||
api_server_logger.error(traceback.format_exc())
|
||||
return None
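A hedged sketch of what the top-k slicing above does for a toy payload (row layout assumed from how LogprobsLists is used in this file: the sampled token comes first, followed by the candidates):

logprob_token_ids = [[101, 7, 42]]          # toy ids, one row per decoded step
row_logprobs = [[-0.01, -2.3, -4.1]]
request_top_logprobs = 2
topk_ids = logprob_token_ids[0][:request_top_logprobs + 1]   # [101, 7, 42]
topk_lps = row_logprobs[0][:request_top_logprobs + 1]        # entry 0 is sampled, 1: become top_logprobs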
@@ -74,7 +74,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_ATTENTION_BACKEND":
|
||||
lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
|
||||
|
||||
# Set sampling class. "base", "air" and "rejection" can be set currently.
|
||||
# Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently.
|
||||
"FD_SAMPLING_CLASS":
|
||||
lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
|
||||
|
||||
@@ -82,13 +82,45 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_MOE_BACKEND":
|
||||
lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
|
||||
|
||||
# Set whether to disable recomputing the request when the KV cache is full.
|
||||
"FD_DISABLED_RECOVER":
|
||||
lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
|
||||
|
||||
# Set triton kernel JIT compilation directory.
|
||||
"FD_TRITON_KERNEL_CACHE_DIR":
|
||||
lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None),
|
||||
|
||||
# Whether to allow transition from standalone PD disaggregation to centralized inference
|
||||
"FD_PD_CHANGEABLE":
|
||||
lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
|
||||
lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
|
||||
|
||||
# Whether to use fastsafetensor load weight (0 or 1)
|
||||
"FD_USE_FASTSAFETENSOR":
|
||||
lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
|
||||
|
||||
# Whether to enable tracing.
|
||||
"TRACES_ENABLE":
|
||||
lambda: os.getenv("TRACES_ENABLE", "false"),
|
||||
|
||||
# Set trace server name.
|
||||
"FD_SERVICE_NAME":
|
||||
lambda: os.getenv("FD_SERVICE_NAME", "FastDeploy"),
|
||||
|
||||
# Set trace host name.
|
||||
"FD_HOST_NAME":
|
||||
lambda: os.getenv("FD_HOST_NAME", "localhost"),
|
||||
|
||||
# Set trace exporter.
|
||||
"TRACES_EXPORTER":
|
||||
lambda: os.getenv("TRACES_EXPORTER", "console"),
|
||||
|
||||
# Set trace exporter_otlp_endpoint.
|
||||
"EXPORTER_OTLP_ENDPOINT":
|
||||
lambda: os.getenv("EXPORTER_OTLP_ENDPOINT"),
|
||||
|
||||
# Set trace exporter_otlp_headers.
|
||||
"EXPORTER_OTLP_HEADERS":
|
||||
lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
|
||||
}
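A short, hedged example of enabling tracing through these variables before the server starts (the endpoint URL and header value are placeholders):

import os
os.environ["TRACES_ENABLE"] = "true"
os.environ["FD_SERVICE_NAME"] = "FastDeploy"
os.environ["TRACES_EXPORTER"] = "otlp"                                      # "console" is the default
os.environ["EXPORTER_OTLP_ENDPOINT"] = "http://localhost:4318/v1/traces"    # placeholder endpoint
os.environ["EXPORTER_OTLP_HEADERS"] = "Authentication=token123"             # "k=v,k2=v2" format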
@@ -20,13 +20,13 @@ import numpy as np
|
||||
from paddleformers.generation import GenerationConfig
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
|
||||
|
||||
from fastdeploy.input.text_processor import BaseDataProcessor
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class ErnieProcessor(BaseDataProcessor):
|
||||
"""
|
||||
Initialize the model instance.
|
||||
@@ -100,7 +100,6 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
|
||||
if request.prompt_token_ids is None or len(
|
||||
request.prompt_token_ids) == 0:
|
||||
system = request.get("system")
|
||||
if request.prompt is None and request.messages is None:
|
||||
raise ValueError(
|
||||
f"The request should have `input_ids`, `text` or `messages`: {request}.")
|
||||
@@ -149,7 +148,6 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
request['stop_token_ids'] = stop_seqs
|
||||
request['stop_seqs_len'] = stop_seqs_len
|
||||
|
||||
system = request.get("system")
|
||||
# Process prompt_token_ids
|
||||
if not request.get('prompt_token_ids'):
|
||||
if request.get('prompt') is None and request.get(
|
||||
@@ -160,6 +158,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
if request.get('prompt'):
|
||||
prompt = request.get('prompt')
|
||||
prompt = prompt[0] if isinstance(prompt, list) else prompt
|
||||
|
||||
tokens = self.tokenizer.tokenize(prompt)
|
||||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
request['prompt_token_ids'] = token_ids
|
||||
@@ -248,7 +247,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
if is_end and len(token_ids) > 0:
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
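A minimal sketch of the `include_stop_str_in_output` behaviour above (toy ids, hypothetical EOS id):

token_ids, eos_token_id = [5, 9, 2], 2                 # 2 stands in for tokenizer.eos_token_id
include_stop_str_in_output = False
if token_ids and not include_stop_str_in_output and token_ids[-1] == eos_token_id:
    token_ids = token_ids[:-1]                          # -> [5, 9]; kept intact when the flag is set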
@@ -283,7 +282,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
req_id = response_dict["request_id"]
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
|
||||
if is_end and len(token_ids) > 0:
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(
|
||||
@@ -442,3 +441,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
data_processor_logger.debug(
|
||||
f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
|
||||
return stop_seqs, stop_seqs_len
|
||||
|
||||
def process_logprob_response(self, token_ids, **kwargs):
|
||||
full_text = self.tokenizer.decode(token_ids, **kwargs)
|
||||
return full_text
|
||||
|
@@ -82,6 +82,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
self.vocab_file = vocab_file
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(vocab_file)
|
||||
# pre-process map-type all spec token for decode accelerate.
|
||||
|
||||
@property
|
||||
def space_token(self):
|
||||
@@ -136,14 +137,19 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
"""doc"""
|
||||
return self.sp_model.id_to_piece(id)
|
||||
|
||||
def spec_init(self):
|
||||
if not hasattr(self, "all_spec_tok"):
|
||||
self.all_spec_tok = set(self.all_special_tokens)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
self.spec_init()
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
# prev_is_special = False
|
||||
for token in tokens:
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
if token in self.all_spec_tok:
|
||||
# if not prev_is_special:
|
||||
# out_string += " "
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
@@ -210,13 +216,14 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
# if isinstance(t, AddedToken)
|
||||
# )
|
||||
|
||||
self.spec_init()
|
||||
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
||||
|
||||
# TODO: should this be in the base class?
|
||||
if hasattr(self, "do_lower_case") and self.do_lower_case:
|
||||
# convert non-special tokens to lowercase
|
||||
escaped_special_toks = [
|
||||
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
|
||||
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok)
|
||||
]
|
||||
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
||||
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
||||
|
@@ -25,6 +25,7 @@ from fastdeploy.utils import data_processor_logger
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class BaseDataProcessor(ABC):
|
||||
"""base class for data processor"""
|
||||
|
||||
@@ -308,6 +309,10 @@ class DataProcessor(BaseDataProcessor):
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
return request
|
||||
|
||||
def process_logprob_response(self, token_ids, **kwargs):
|
||||
full_text = self.tokenizer.decode(token_ids, **kwargs)
|
||||
return full_text
|
||||
|
||||
def process_response(self, response_dict, **kwargs):
|
||||
"""
|
||||
Preprocess the response
|
||||
@@ -350,7 +355,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
if is_end and len(token_ids) > 0:
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
@@ -385,7 +390,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
req_id = response_dict["request_id"]
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
|
||||
if is_end and len(token_ids) > 0:
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(
|
||||
@@ -425,7 +430,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
response_dict, enable_thinking=enable_thinking, **kwargs)
|
||||
else:
|
||||
return self.process_response_dict_normal(
|
||||
response_dict=response_dict, enable_thinking=enable_thinking)
|
||||
response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)
|
||||
|
||||
def text2ids(self, text, max_model_len, raw_request=True):
|
||||
"""
|
||||
|
198
fastdeploy/metrics/trace_util.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from opentelemetry.propagate import inject, extract
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from fastapi import FastAPI
|
||||
from fastdeploy.utils import (llm_logger)
|
||||
from fastdeploy import envs
|
||||
import json
|
||||
|
||||
|
||||
# OpenTelemetry Trace context store in metadata
|
||||
TRACE_CARRIER = "trace_carrier"
|
||||
|
||||
traces_enable = False
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
def set_up():
|
||||
try:
|
||||
# when TRACES_ENABLED=true start trace
|
||||
global traces_enable
|
||||
traces_enable = envs.TRACES_ENABLE.lower() == "true"
|
||||
if not traces_enable:
|
||||
llm_logger.warning("Opentelemetry is DISABLED.")
|
||||
return
|
||||
|
||||
llm_logger.info("Opentelemetry is ENABLED, configuring...")
|
||||
# --- read env ---
|
||||
service_name = envs.FD_SERVICE_NAME
|
||||
host_name = envs.FD_HOST_NAME
|
||||
# --- set attributes (Service Name, Host Name, etc.) ---
|
||||
resource_attributes = {
|
||||
"service.name": service_name
|
||||
}
|
||||
if host_name:
|
||||
resource_attributes["host.name"] = host_name
|
||||
|
||||
resource = Resource(attributes=resource_attributes)
|
||||
|
||||
# --- set Exporter ---
|
||||
exporter_type = envs.TRACES_EXPORTER.lower()
|
||||
if exporter_type == "otlp":
|
||||
endpoint = envs.EXPORTER_OTLP_ENDPOINT # should be set
|
||||
headers = envs.EXPORTER_OTLP_HEADERS # e.g., "Authentication=***,k2=v2"
|
||||
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=endpoint,
|
||||
headers=dict(item.split("=") for item in headers.split(",")) if headers else None
|
||||
)
|
||||
processor = BatchSpanProcessor(otlp_exporter)
|
||||
llm_logger.info(f"Using OTLP Exporter, sending to {endpoint} with headers {headers}")
|
||||
else: # default console
|
||||
processor = BatchSpanProcessor(ConsoleSpanExporter())
|
||||
llm_logger.info("Using Console Exporter.")
|
||||
|
||||
# --- set Tracer Provider ---
|
||||
provider = TracerProvider(resource=resource)
|
||||
provider.add_span_processor(processor)
|
||||
trace.set_tracer_provider(provider)
|
||||
global tracer
|
||||
tracer = trace.get_tracer(__name__)
|
||||
except:
|
||||
llm_logger.error("set_up failed")
|
||||
pass
|
||||
|
||||
def instrument(app: FastAPI):
|
||||
try:
|
||||
set_up()
|
||||
if traces_enable:
|
||||
llm_logger.info("Applying instrumentors...")
|
||||
FastAPIInstrumentor.instrument_app(app)
|
||||
except:
|
||||
llm_logger.info("instrument failed")
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def inject_to_metadata(request, metadata_attr='metadata'):
|
||||
"""
|
||||
Inject OpenTelemetry trace context into the metadata field of the request.
|
||||
|
||||
Parameters:
|
||||
request: can be a dict or object, with metadata attributes or fields.
|
||||
metadata_attr: the field name of metadata, default is 'metadata'.
|
||||
|
||||
Operation:
|
||||
- If metadata does not exist, create a new one and mount it on the request.
|
||||
- Inject the current trace context as a JSON string and store it in metadata.
|
||||
- Use the key TRACE_CARRIER to store the injected content.
|
||||
|
||||
Note:
|
||||
- This function is a non-blocking operation, and errors are silently ignored.
|
||||
- If there is no metadata attribute in the request, an empty dict will be created for it as its attribute
|
||||
"""
|
||||
try:
|
||||
if request is None or traces_enable == False:
|
||||
return
|
||||
|
||||
metadata = request.get(metadata_attr) if isinstance(request, dict) else getattr(request, metadata_attr, None)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if isinstance(request, dict):
|
||||
request[metadata_attr] = metadata
|
||||
else:
|
||||
setattr(request, metadata_attr, metadata)
|
||||
|
||||
trace_carrier = {}
|
||||
inject(trace_carrier)
|
||||
trace_carrier_json_string = json.dumps(trace_carrier)
|
||||
metadata[TRACE_CARRIER] = trace_carrier_json_string
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def extract_from_metadata(request, metadata_attr='metadata'):
|
||||
"""
|
||||
Extract trace context from metadata of request object (dict or class instance).
|
||||
|
||||
Parameters:
|
||||
request: can be a dictionary or any object, containing metadata attributes or fields.
|
||||
metadata_attr: metadata field name, default is 'metadata'.
|
||||
|
||||
Returns:
|
||||
- Extraction success: returns OpenTelemetry context object (Context)
|
||||
- Extraction failure or exception: returns None
|
||||
"""
|
||||
try:
|
||||
metadata = request.get(metadata_attr) if isinstance(request, dict) else getattr(request, metadata_attr, None)
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
trace_carrier_json_string = metadata.get(TRACE_CARRIER)
|
||||
if trace_carrier_json_string is None:
|
||||
return None
|
||||
|
||||
trace_carrier = json.loads(trace_carrier_json_string)
|
||||
ctx = extract(trace_carrier)
|
||||
return ctx
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def extract_from_request(request):
|
||||
"""
|
||||
Extract trace context from trace_carrier of request object (dict or class instance).
|
||||
|
||||
Parameters:
|
||||
request: can be a dictionary or any object, containing metadata attributes or fields.
|
||||
(unlike extract_from_metadata, the carrier is read directly from the request's `trace_carrier` attribute)
|
||||
|
||||
Returns:
|
||||
- Extraction success: returns OpenTelemetry context object (Context)
|
||||
- Extraction failure or exception: returns None
|
||||
"""
|
||||
try:
|
||||
trace_carrier_info = getattr(request, TRACE_CARRIER, None)
|
||||
|
||||
if trace_carrier_info is None:
|
||||
return None
|
||||
|
||||
trace_carrier = json.loads(trace_carrier_info)
|
||||
ctx = extract(trace_carrier)
|
||||
return ctx
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def start_span(span_name, request, kind=trace.SpanKind.CLIENT):
|
||||
"""
|
||||
just start a new span in request trace context
|
||||
"""
|
||||
try:
|
||||
if not traces_enable:
|
||||
return
|
||||
# extract Trace context from request.metadata.trace_carrier
|
||||
ctx = extract_from_metadata(request)
|
||||
with tracer.start_as_current_span(span_name, context=ctx, kind=kind) as span:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def start_span_request(span_name, request, kind=trace.SpanKind.CLIENT):
|
||||
"""
|
||||
just start a new span in request trace context
|
||||
"""
|
||||
try:
|
||||
if not traces_enable:
|
||||
return
|
||||
# extract Trace context from request.metadata.trace_carrier
|
||||
ctx = extract_from_request(request)
|
||||
with tracer.start_as_current_span(span_name, context=ctx, kind=kind) as span:
|
||||
pass
|
||||
except:
|
||||
pass
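A hedged usage sketch of the helpers in this new file (the span name is arbitrary; everything is a no-op unless TRACES_ENABLE=true):

request = {"prompt": "hello", "metadata": {}}
inject_to_metadata(request)                  # stores the carrier under metadata["trace_carrier"]
start_span("fastdeploy.request", request)    # reopens a child span from that carrier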
@@ -12,14 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .attention import Attention
|
||||
from .append_attn_backend import AppendAttentionBackend
|
||||
from .attention_selecter import get_attention_backend
|
||||
from .base_attention_backend import AttentionBackend
|
||||
from .flash_attn_backend import FlashAttentionBackend
|
||||
from .mla_attention_backend import MLAAttentionBackend
|
||||
from .native_paddle_backend import PaddleNativeAttnBackend
|
||||
from .xpu_attn_backend import XPUAttentionBackend
|
||||
|
||||
__all__ = [
|
||||
"Attention", "AttentionBackend", "PaddleNativeAttnBackend",
|
||||
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend"
|
||||
"AttentionBackend", "PaddleNativeAttnBackend",
|
||||
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
|
||||
"MLAAttentionBackend", "FlashAttentionBackend"
|
||||
]
|
||||
|
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
@@ -187,6 +187,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
|
@@ -111,6 +111,8 @@ class Attention(nn.Layer):
|
||||
k: paddle.Tensor = None,
|
||||
v: paddle.Tensor = None,
|
||||
qkv: paddle.Tensor = None,
|
||||
compressed_kv: paddle.Tensor = None,
|
||||
k_pe: paddle.Tensor = None,
|
||||
forward_meta: ForwardMeta = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
@@ -120,12 +122,16 @@ class Attention(nn.Layer):
|
||||
k: the key tensor
|
||||
v: the value tensor
|
||||
forward_meta: the forward meta data
|
||||
compressed_kv: optional compressed key-value cache (for MLA)
|
||||
k_pe: optional key positional encoding (for MLA)
|
||||
"""
|
||||
return forward_meta.attn_backend.forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
self,
|
||||
forward_meta,
|
||||
)
|
||||
|
@@ -16,6 +16,7 @@
|
||||
|
||||
from functools import cache
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.platforms import _Backend, current_platform
|
||||
from fastdeploy.utils import resolve_obj_from_strname
|
||||
|
||||
@@ -40,6 +41,7 @@ def _get_attn_backend(selected_backend: str) -> object:
|
||||
return resolve_obj_from_strname(attention_cls)
|
||||
|
||||
|
||||
def get_attention_backend(selected_backend):
|
||||
"""Selects which attention backend ."""
|
||||
return _get_attn_backend(selected_backend)
|
||||
def get_attention_backend() -> object:
|
||||
"""Selects which attention backend."""
|
||||
attention_backend = envs.FD_ATTENTION_BACKEND
|
||||
return _get_attn_backend(attention_backend)
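After this change the backend comes from the environment rather than a call argument; a hedged sketch:

import os
os.environ["FD_ATTENTION_BACKEND"] = "APPEND_ATTN"   # default, per the envs table above
# backend_cls = get_attention_backend()              # resolves the class via _get_attn_backend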
@@ -46,6 +46,8 @@ class AttentionBackend(ABC):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
@@ -56,6 +58,8 @@ class AttentionBackend(ABC):
|
||||
k: The key tensor.
|
||||
v: The value tensor.
|
||||
layer: The layer that will be used for the forward.
|
||||
compressed_kv: optional compressed key-value cache (for MLA)
|
||||
k_pe: optional key positional encoding (for MLA)
|
||||
forward_meta: The forward metadata.
|
||||
"""
|
||||
if forward_meta.forward_mode.is_mixed():
|
||||
@@ -64,6 +68,8 @@ class AttentionBackend(ABC):
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
layer,
|
||||
forward_meta,
|
||||
)
|
||||
@@ -73,6 +79,8 @@ class AttentionBackend(ABC):
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
layer,
|
||||
forward_meta,
|
||||
)
|
||||
@@ -82,6 +90,8 @@ class AttentionBackend(ABC):
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
layer,
|
||||
forward_meta,
|
||||
)
|
||||
@@ -92,6 +102,8 @@ class AttentionBackend(ABC):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
@@ -104,6 +116,8 @@ class AttentionBackend(ABC):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
@@ -116,6 +130,8 @@ class AttentionBackend(ABC):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
|
247
fastdeploy/model_executor/layers/attention/flash_attn_backend.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
try:
|
||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||
except:
|
||||
flash_attention_v3_varlen = None
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, gqa_rope_write_cache,
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
FlashAttentionMetadata
|
||||
"""
|
||||
max_len_kv: paddle.Tensor = None
|
||||
set_max_lengths: int = -1
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
decoder_batch_ids: paddle.Tensor = None
|
||||
decoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
decoder_num_blocks: paddle.Tensor = None
|
||||
|
||||
encoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
decoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
|
||||
cu_seqlens_q: paddle.Tensor = None
|
||||
cu_seqlens_k: paddle.Tensor = None
|
||||
max_seqlen_q: int = 0
|
||||
max_seqlen_k: int = 0
|
||||
|
||||
pre_cache_batch_ids = None
|
||||
pre_cache_tile_ids_per_batch = None
|
||||
pre_cache_num_blocks_cpu = None
|
||||
kv_token_num_cpu = None
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
FlashAttention backend implementation
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||
head_dim: int):
|
||||
"""
|
||||
FlashAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: FlashAttentionMetadata = None
|
||||
self.max_seq_len = fd_config.parallel_config.max_model_len
|
||||
self.causal = getattr(fd_config.model_config, "causal", True)
|
||||
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.block_size = fd_config.parallel_config.block_size
|
||||
self.num_layers: int = fd_config.model_config.num_layers
|
||||
|
||||
self.speculative_method = fd_config.speculative_config.method
|
||||
self.use_speculate = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
fd_config.parallel_config.expert_parallel_rank = 0
|
||||
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
|
||||
fd_config.parallel_config.expert_parallel_rank
|
||||
if self.device_id is None:
|
||||
self.device_id = device_id
|
||||
else:
|
||||
self.device_id = self.device_id.split(",")[device_id]
|
||||
|
||||
def get_attntion_meta(self):
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
def get_kv_cache_shape(
|
||||
self,
|
||||
max_num_blocks: int,
|
||||
):
|
||||
"""
|
||||
Calculate kv cache shape
|
||||
"""
|
||||
return (max_num_blocks, self.kv_num_heads, self.block_size,
|
||||
self.head_dim)
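A worked example with purely illustrative sizes:

max_num_blocks, kv_num_heads, block_size, head_dim = 1000, 8, 64, 128
kv_cache_shape = (max_num_blocks, kv_num_heads, block_size, head_dim)   # (1000, 8, 64, 128); one such tensor each for K and V per layer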
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
metadata = FlashAttentionMetadata()
|
||||
metadata.encoder_block_shape_q = 64
|
||||
metadata.decoder_block_shape_q = 16
|
||||
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
metadata.set_max_lengths,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.encoder_block_shape_q,
|
||||
metadata.decoder_block_shape_q,
|
||||
self.num_heads // self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
)
|
||||
|
||||
(
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
metadata.kv_token_num_cpu,
|
||||
) = pre_cache_len_concat(
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
metadata.set_max_lengths[2],
|
||||
self.block_size,
|
||||
)
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
self.attention_metadata = metadata
|
||||
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
||||
forward_meta.decoder_tile_ids_per_batch.copy_(
|
||||
metadata.decoder_tile_ids_per_batch, False)
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
q, k, v, _ = gqa_rope_write_cache(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.rotary_embs,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.pre_cache_batch_ids,
|
||||
metadata.pre_cache_tile_ids_per_batch,
|
||||
metadata.pre_cache_num_blocks_cpu,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
metadata.kv_token_num_cpu[0],
|
||||
self.max_seq_len,
|
||||
getattr(layer, "cache_quant_type_str", "none"),
|
||||
)
|
||||
res = flash_attention_v3_varlen(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
max_seqlen_q=metadata.set_max_lengths[0],
|
||||
max_seqlen_k=metadata.set_max_lengths[3],
|
||||
causal=self.causal,
|
||||
)[0].reshape([-1, self.hidden_size])
|
||||
return res
|
@@ -0,0 +1,490 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import paddle
|
||||
from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block, init_signal_layerwise,
|
||||
open_shm_and_get_meta_signal)
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import (decode_mla_write_cache,
|
||||
multi_head_latent_attention,
|
||||
prefill_mla_write_cache)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
"""
|
||||
"""
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
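A quick numeric check of the formula above (scale=40 and mscale=1.0 are the values annotated in the MLA backend __init__ below):

import math
# yarn_get_mscale(40, 1.0) = 0.1 * 1.0 * ln(40) + 1.0 ≈ 1.369; the backend then
# multiplies attn_softmax_scale by this factor twice (i.e. squares it).
print(0.1 * 1.0 * math.log(40) + 1.0)   # ~1.3689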
@dataclass
|
||||
class MLAAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
MLAAttentionMetadata for Multi-Head Latent Attention
|
||||
"""
|
||||
max_len_kv: paddle.Tensor = None
|
||||
set_max_lengths: int = -1
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
decoder_batch_ids: paddle.Tensor = None
|
||||
decoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
decoder_num_blocks: paddle.Tensor = None
|
||||
|
||||
_dtype: _DTypeLiteral = paddle.bfloat16
|
||||
encoder_max_partition_size: int = 32768
|
||||
max_partition_size: int = 32768
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
attn_mask: Optional[paddle.Tensor] = None
|
||||
encoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
decoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
_fuse_kernel_compute_dtype: str = "bf16"
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
|
||||
|
||||
|
||||
class MLAAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
MLA Attention Backend implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||
head_dim: int) -> None:
|
||||
"""
|
||||
MLAAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: MLAAttentionMetadata = None
|
||||
|
||||
# Basic configuration
|
||||
self.block_size: int = fd_config.parallel_config.block_size
|
||||
self.max_seq_len: int = fd_config.parallel_config.max_model_len
|
||||
self.rope_theta: float = (10000.0
|
||||
if fd_config.model_config.rope_theta is None
|
||||
else fd_config.model_config.rope_theta)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.speculative_method: str = fd_config.speculative_config.method
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
self.head_dim: int = fd_config.model_config.head_dim
|
||||
self.num_layers: int = fd_config.model_config.num_layers
|
||||
|
||||
# For Multi Head Latent Attention
|
||||
self.kv_lora_rank: int = fd_config.model_config.deepseekv3.kv_lora_rank
|
||||
self.qk_rope_head_dim: int = fd_config.model_config.deepseekv3.qk_rope_head_dim
|
||||
self.qk_head_dim: int = fd_config.model_config.deepseekv3.qk_nope_head_dim \
|
||||
+ fd_config.model_config.deepseekv3.qk_rope_head_dim
|
||||
self.attn_softmax_scale: float = self.qk_head_dim**-0.5
|
||||
if fd_config.model_config.deepseekv3.rope_scaling:
|
||||
mscale_all_dim = fd_config.model_config.deepseekv3.rope_scaling.get(
|
||||
"mscale_all_dim", False) # 1.0
|
||||
scaling_factor = fd_config.model_config.deepseekv3.rope_scaling[
|
||||
"factor"] # 40
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
if self.device_id is None:
|
||||
self.device_id = self.rank
|
||||
else:
|
||||
self.device_id = self.device_id.split(",")[self.rank]
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
|
||||
metadata = MLAAttentionMetadata()
|
||||
metadata.encoder_block_shape_q = 64
|
||||
metadata.decoder_block_shape_q = 16
|
||||
metadata.max_partition_size = 32768
|
||||
metadata.encoder_max_partition_size = self.max_seq_len
|
||||
metadata._dtype = paddle.get_default_dtype()
|
||||
if metadata._dtype == "bfloat16":
|
||||
metadata._fuse_kernel_compute_dtype = "bf16"
|
||||
elif metadata._dtype == "float16":
|
||||
metadata._fuse_kernel_compute_dtype = "fp16"
|
||||
elif metadata._dtype == "float32":
|
||||
metadata._fuse_kernel_compute_dtype = "fp32"
|
||||
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.attn_mask = forward_meta.attn_mask
|
||||
metadata.pre_caches_length = forward_meta.pre_caches_length
|
||||
|
||||
(
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.max_len_kv,
|
||||
metadata.set_max_lengths,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.encoder_block_shape_q,
|
||||
metadata.decoder_block_shape_q,
|
||||
self.num_heads // self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
)
|
||||
|
||||
# MLA
|
||||
metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
|
||||
metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
|
||||
self.attention_metadata: AttentionMetadata = metadata
|
||||
|
||||
def get_attntion_meta(self) -> AttentionMetadata:
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
def get_kv_cache_shape(self,
|
||||
max_num_blocks: int) -> Tuple[int, int, int, int]:
|
||||
"""
|
||||
Calculate kv cache shape for MLA
|
||||
"""
|
||||
return (max_num_blocks, 1, self.block_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim)
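A worked example with assumed DeepSeek-V3-style sizes (kv_lora_rank=512 and qk_rope_head_dim=64 are illustrative, not read from this diff):

kv_lora_rank, qk_rope_head_dim, block_size = 512, 64, 64
latent_cache_shape = (1000, 1, block_size, kv_lora_rank + qk_rope_head_dim)   # (1000, 1, 64, 576)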
def forward_extend(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Forward pass for the prefill stage.
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
|
||||
forward_meta, 'caches') else None
|
||||
|
||||
# Write to the KV cache
|
||||
prefill_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
"none",
|
||||
getattr(forward_meta, 'max_input_length', -1),
|
||||
)
|
||||
|
||||
# Flash attention computation
|
||||
fmha_out = flash_attn_unpadded(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.cu_seqlens_k,
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_enc_len_this_time,
|
||||
self.attn_softmax_scale,
|
||||
causal=True,
|
||||
training=False,
|
||||
)[0]
|
||||
|
||||
return fmha_out
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Forward pass for the decode stage.
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
|
||||
forward_meta, 'caches') else None
|
||||
|
||||
# Get speculative decoding parameters
|
||||
speculate_decoder = self.speculative_method is not None
|
||||
speculate_max_tokens = self.speculate_max_draft_token_num
|
||||
|
||||
# Write to the KV cache
|
||||
decode_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
"none",
|
||||
self.max_seq_len,
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
# Multi-head latent attention computation
|
||||
fmha_out = multi_head_latent_attention(
|
||||
q,
|
||||
latent_cache,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.
|
||||
decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_dec_len_this_time,
|
||||
metadata.max_len_kv,
|
||||
None, # attn_mask
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
None, # cache_k_quant_scales
|
||||
None, # cache_v_quant_scales
|
||||
None, # cache_k_dequant_scales
|
||||
None, # cache_v_dequant_scales
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None, # out_shifts
|
||||
None, # out_smooths
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
"none", # cache_quant_type
|
||||
self.kv_lora_rank,
|
||||
self.max_seq_len,
|
||||
self.attn_softmax_scale,
|
||||
0.0, # quant_max_bound
|
||||
0.0, # quant_min_bound
|
||||
0.0, # out_linear_in_scale
|
||||
speculate_max_tokens,
|
||||
True, # causal
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
return fmha_out
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Forward pass for mixed (prefill + decode) batches.
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
speculate_decoder = self.speculative_method is not None
|
||||
speculate_max_tokens = self.speculate_max_draft_token_num
|
||||
|
||||
decode_stage = forward_meta.is_decode_batch
|
||||
prefill_stage = not (forward_meta.is_decode_batch)
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
|
||||
forward_meta, 'caches') else None
|
||||
|
||||
if prefill_stage:
|
||||
# Write to the KV cache
|
||||
prefill_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
"none",
|
||||
self.max_seq_len,
|
||||
)
|
||||
|
||||
# FA
|
||||
fmha_out = flash_attn_unpadded(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.cu_seqlens_k,
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_enc_len_this_time,
|
||||
self.attn_softmax_scale,
|
||||
causal=True,
|
||||
training=False,
|
||||
)[0]
|
||||
|
||||
return fmha_out
|
||||
|
||||
# Decode
|
||||
if decode_stage:
|
||||
# MLA: write to the KV cache
|
||||
decode_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
"none",
|
||||
self.max_seq_len,
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
# Multi-head latent attention (MLA) computation
|
||||
fmha_out = multi_head_latent_attention(
|
||||
q,
|
||||
latent_cache,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.cum_offsets,
|
||||
metadata.block_tables,
|
||||
metadata.encoder_batch_ids,
|
||||
metadata.encoder_tile_ids_per_batch,
|
||||
metadata.encoder_num_blocks,
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.
|
||||
decoder_num_blocks, # PaddleNLP passes decoder_num_blocks_cpu here
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_dec_len_this_time,
|
||||
metadata.max_len_kv,
|
||||
None, # attn_mask
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
None, # cache_k_quant_scales
|
||||
None, # cache_v_quant_scales
|
||||
None, # cache_k_dequant_scales
|
||||
None, # cache_v_dequant_scales
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None, # out_shifts
|
||||
None, # out_smooths
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
"none", # cache_quant_type
|
||||
self.kv_lora_rank,
|
||||
self.max_seq_len,
|
||||
self.attn_softmax_scale,
|
||||
0.0, # quant_max_bound
|
||||
0.0, # quant_min_bound
|
||||
0.0, # out_linear_in_scale
|
||||
speculate_max_tokens,
|
||||
True, # causal
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
return fmha_out
|
@@ -17,10 +17,16 @@
|
||||
from .append_attention import append_attention
|
||||
from .get_block_shape_and_split_kv_block import \
|
||||
get_block_shape_and_split_kv_block
|
||||
from .gqa_rope_write_cache import gqa_rope_write_cache
|
||||
from .init_signal_layerwise import init_signal_layerwise
|
||||
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
|
||||
from .pre_cache_len_concat import pre_cache_len_concat
|
||||
|
||||
__all__ = [
|
||||
"get_block_shape_and_split_kv_block", "append_attention",
|
||||
"open_shm_and_get_meta_signal", "init_signal_layerwise"
|
||||
"get_block_shape_and_split_kv_block",
|
||||
"append_attention",
|
||||
"open_shm_and_get_meta_signal",
|
||||
"init_signal_layerwise",
|
||||
"gqa_rope_write_cache",
|
||||
"pre_cache_len_concat",
|
||||
]
|
||||
|
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def gqa_rope_write_cache(
|
||||
qkv: paddle.Tensor,
|
||||
key_cache: paddle.Tensor,
|
||||
value_cache: paddle.Tensor,
|
||||
cu_seqlens_q: paddle.Tensor,
|
||||
cu_seqlens_k: paddle.Tensor,
|
||||
rotary_embs: paddle.Tensor,
|
||||
seq_lens_this_time: paddle.Tensor,
|
||||
seq_lens_encoder: paddle.Tensor,
|
||||
seq_lens_decoder: paddle.Tensor,
|
||||
padding_offsets: paddle.Tensor,
|
||||
cum_offsets: paddle.Tensor,
|
||||
block_tables: paddle.Tensor,
|
||||
kv_batch_ids: paddle.Tensor,
|
||||
kv_tile_ids_per_batch: paddle.Tensor,
|
||||
kv_num_blocks: paddle.Tensor,
|
||||
cache_batch_ids: paddle.Tensor,
|
||||
cache_tile_ids_per_batch: paddle.Tensor,
|
||||
cache_num_blocks: paddle.Tensor,
|
||||
cache_k_quant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_v_quant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_k_dequant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_v_dequant_scales: Optional[paddle.Tensor] = None,
|
||||
cache_k_zp: Optional[paddle.Tensor] = None,
|
||||
cache_v_zp: Optional[paddle.Tensor] = None,
|
||||
kv_signal_data: Optional[paddle.Tensor] = None,
|
||||
kv_token_num: int = 1,
|
||||
max_seq_len: int = 0,
|
||||
cache_quant_type: str = "none"):
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
|
||||
q, k, v, qkv_ = gqa_rope_write_cache(
|
||||
qkv, key_cache, value_cache, cu_seqlens_q, cu_seqlens_k,
|
||||
rotary_embs, seq_lens_this_time, seq_lens_encoder,
|
||||
seq_lens_decoder, padding_offsets, cum_offsets, block_tables,
|
||||
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
|
||||
cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks,
|
||||
cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales,
|
||||
cache_v_dequant_scales, cache_k_zp, cache_v_zp, kv_signal_data,
|
||||
kv_token_num, max_seq_len, cache_quant_type)
|
||||
return q, k, v, qkv_
|
||||
else:
|
||||
raise NotImplementedError()
|
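Both of the new wrappers in this change (gqa_rope_write_cache above and pre_cache_len_concat below) are thin platform-dispatch shims: they lazily import the compiled GPU custom op and raise on unsupported platforms. A minimal sketch of that pattern, using a hypothetical op name (my_fused_op does not exist in FastDeploy), might look like:

import paddle

from fastdeploy.platforms import current_platform


def my_fused_op(x: paddle.Tensor, *args, **kwargs):
    """Hypothetical wrapper illustrating the dispatch pattern used by the new ops."""
    if current_platform.is_cuda():
        # The compiled custom op is only importable when the GPU package was built.
        from fastdeploy.model_executor.ops.gpu import my_fused_op as gpu_op
        return gpu_op(x, *args, **kwargs)
    raise NotImplementedError("my_fused_op is currently only implemented for CUDA")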
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License,
|
||||
Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
||||
either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def pre_cache_len_concat(seq_lens_decoder: paddle.Tensor,
|
||||
seq_lens_this_time: paddle.Tensor,
|
||||
max_dec_len: int = 0,
|
||||
block_size: int = 64):
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
|
||||
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time,
|
||||
max_dec_len, block_size)
|
||||
return out
|
||||
else:
|
||||
raise NotImplementedError()
|
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
@@ -149,6 +149,8 @@ class XPUAttentionBackend(AttentionBackend):
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
compressed_kv: paddle.Tensor,
|
||||
k_pe: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
|
@@ -41,16 +41,12 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
"""
|
||||
Create weights for linear layer on XPU
|
||||
"""
|
||||
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
|
||||
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
|
||||
layer.linear_weight_shape.reverse()
|
||||
if self.quant_config.name() == "weight_only_int4":
|
||||
layer.linear_weight_shape[0] //= 2
|
||||
layer.weight_dtype = "int8"
|
||||
linear_weight_scale_shape = [layer.embed_dim]
|
||||
if hasattr(layer, "linear_weight_shape"):
|
||||
if isinstance(layer.linear_weight_shape, list):
|
||||
layer_weight_shape = layer.linear_weight_shape
|
||||
linear_weight_scale_shape = layer_weight_shape[:1]
|
||||
|
||||
layer.linear_weight_scale = layer.create_parameter(
|
||||
shape=linear_weight_scale_shape,
|
||||
dtype="float32",
|
||||
|
@@ -14,10 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
|
||||
@@ -28,12 +33,12 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
num_embeddings,
|
||||
embedding_dim=768,
|
||||
params_dtype="bfloat16",
|
||||
fd_config: FDConfig,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int = 768,
|
||||
params_dtype: str = "bfloat16",
|
||||
prefix="",
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the VocabParallelEmbedding layer for the model.
|
||||
|
||||
@@ -41,28 +46,28 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
num_embeddings : vocabulary size.
|
||||
embedding_dim : size of hidden state.
|
||||
params_dtype : data type of parameters.
|
||||
prefix (str): Unique name of the layer, used for naming internal attributes,
|
||||
you can give it any name you like.
|
||||
num_embeddings (int) : vocabulary size.
|
||||
embedding_dim (int) : size of hidden state.
|
||||
params_dtype (str) : data type of parameters.
|
||||
prefix (str): The name of current layer. Defaults to "".
|
||||
"""
|
||||
super().__init__()
|
||||
self.fd_config = fd_config
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
self.mp_rank = hcg.get_model_parallel_rank()
|
||||
self.column_cut = fd_config.parallel_config.column_cut
|
||||
self.world_size = hcg.get_model_parallel_world_size()
|
||||
self.ring_id = hcg.get_model_parallel_group().id
|
||||
self.use_rope = fd_config.model_config.use_rope
|
||||
self.rope_head_dim = fd_config.model_config.rope_head_dim
|
||||
self.use_ep = fd_config.parallel_config.use_ep
|
||||
self.hidden_dropout_prob = fd_config.model_config.hidden_dropout_prob
|
||||
self.initializer_range = fd_config.model_config.initializer_range
|
||||
self.sequence_parallel = fd_config.parallel_config.sequence_parallel
|
||||
self.max_position_embeddings = fd_config.model_config.max_position_embeddings
|
||||
self.freeze_embedding = fd_config.model_config.freeze_embedding
|
||||
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
|
||||
self.mp_rank: int = hcg.get_model_parallel_rank()
|
||||
self.column_cut = False
|
||||
self.world_size: int = hcg.get_model_parallel_world_size()
|
||||
self.ring_id: int = hcg.get_model_parallel_group().id
|
||||
self.use_rope: bool = fd_config.model_config.use_rope
|
||||
self.rope_head_dim: int = fd_config.model_config.rope_head_dim
|
||||
self.use_ep: bool = fd_config.parallel_config.use_ep
|
||||
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
|
||||
self.initializer_range: float = fd_config.model_config.initializer_range
|
||||
self.sequence_parallel: bool = fd_config.parallel_config.sequence_parallel
|
||||
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
|
||||
self.freeze_embedding: bool = fd_config.model_config.freeze_embedding
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
self.params_dtype: str = params_dtype
|
||||
|
||||
if self.use_ep:
|
||||
self.word_embeddings = nn.Embedding(
|
||||
@@ -109,7 +114,8 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
self.rope_head_dim_shape_tensor = paddle.ones((self.rope_head_dim),
|
||||
dtype="int8")
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
@@ -125,7 +131,7 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
|
||||
paddle.get_default_dtype()))
|
||||
|
||||
def forward(self, ids_remove_padding=None):
|
||||
def forward(self, ids_remove_padding=None) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
|
@@ -216,6 +216,14 @@ class ReplicatedLinear(LinearBase):
|
||||
with_bias=with_bias,
|
||||
add_bias=add_bias,
|
||||
skip_quant=skip_quant)
|
||||
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.linear_weight_shape = [
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
]
|
||||
if fd_config.quant_config:
|
||||
self.quant_method.create_weights(self)
|
||||
self.init_weight()
|
||||
|
||||
|
||||
@@ -259,7 +267,10 @@ class ColumnParallelLinear(LinearBase):
|
||||
skip_quant=skip_quant)
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.input_size = input_size
|
||||
self.output_size = divide(output_size, self.nranks)
|
||||
self.output_size = divide(
|
||||
output_size,
|
||||
self.nranks) # Split the output_size using TP inference.
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.linear_weight_shape = [
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
@@ -282,7 +293,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
)
|
||||
if self.nranks > 0:
|
||||
# col parallel
|
||||
_set_var_distributed(self.linear_weight, split_axis=-1)
|
||||
_set_var_distributed(self.linear_weight, split_axis=1)
|
||||
|
||||
self.linear_bias = None
|
||||
if self.with_bias:
|
||||
@@ -293,7 +304,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
)
|
||||
if self.nranks > 0:
|
||||
# col parallel
|
||||
_set_var_distributed(self.linear_bias, split_axis=-1)
|
||||
_set_var_distributed(self.linear_bias, split_axis=1)
|
||||
|
||||
# smooth quant
|
||||
self.linear_shift = None
|
||||
@@ -318,7 +329,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
with_bias: bool = False,
|
||||
add_bias: bool = False,
|
||||
activation: str = "gelu",
|
||||
use_fast_ffn: bool = False,
|
||||
skip_quant: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -333,13 +343,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
with_bias (bool): Whether to include bias or not. Defaults to False.
|
||||
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
|
||||
activation (str): Activation function to use. Defaults to "gelu".
|
||||
use_fast_ffn (bool): Whether to use a faster FFN implementation.
|
||||
Defaults to False.
|
||||
skip_quant (bool): Whether to skip quantization. Defaults to False.
|
||||
"""
|
||||
self.use_fast_ffn = use_fast_ffn
|
||||
self.activation = activation
|
||||
self.embed_dim = fd_config.model_config.hidden_size
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
|
||||
super().__init__(fd_config=fd_config,
|
||||
@@ -374,23 +381,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"gate_proj")
|
||||
bias_tensor = get_tensor(state_dict.pop(gate_bias_key)).astype(
|
||||
paddle.get_default_dtype())
|
||||
converted_bias_tensor = paddle.zeros(shape=list(
|
||||
bias_tensor.shape),
|
||||
dtype=bias_tensor.dtype)
|
||||
if not self.use_fast_ffn:
|
||||
converted_bias_tensor = paddle.concat(
|
||||
[bias_tensor[::2], bias_tensor[1::2]], axis=0)
|
||||
else:
|
||||
converted_bias_tensor = bias_tensor
|
||||
state_dict[self.bias_key] = converted_bias_tensor
|
||||
|
||||
if not self.use_fast_ffn:
|
||||
converted_weight_tensor = paddle.concat(
|
||||
[weight_tensor[:, ::2], weight_tensor[:, 1::2]], axis=1)
|
||||
else:
|
||||
converted_weight_tensor = weight_tensor
|
||||
state_dict[self.bias_key] = bias_tensor
|
||||
|
||||
state_dict[self.weight_key] = converted_weight_tensor
|
||||
state_dict[self.weight_key] = weight_tensor
|
||||
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
@@ -413,12 +407,12 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""
|
||||
self.num_heads = fd_config.model_config.num_attention_heads
|
||||
self.kv_num_heads = fd_config.model_config.num_key_value_heads
|
||||
self.embed_dim = fd_config.model_config.hidden_size
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
|
||||
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
|
||||
input_size = self.embed_dim
|
||||
input_size = self.hidden_size
|
||||
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
|
||||
super().__init__(fd_config=fd_config,
|
||||
prefix=prefix,
|
||||
@@ -448,7 +442,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
weight_tensor = weight_tensor.reshape([
|
||||
(self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) *
|
||||
(self.head_dim),
|
||||
self.embed_dim,
|
||||
self.hidden_size,
|
||||
])
|
||||
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
|
||||
|
||||
@@ -513,6 +507,7 @@ class RowParallelLinear(LinearBase):
|
||||
output_size: int = None,
|
||||
with_bias: bool = False,
|
||||
add_bias: bool = False,
|
||||
reduce_results: bool = True,
|
||||
skip_quant: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -538,10 +533,14 @@ class RowParallelLinear(LinearBase):
|
||||
self.fd_config = fd_config
|
||||
self.skip_quant = False
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.embed_dim = fd_config.model_config.hidden_size
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
|
||||
|
||||
# Split input_size when using TP inference.
|
||||
self.input_size = divide(input_size, self.nranks)
|
||||
self.output_size = output_size
|
||||
|
||||
self.linear_weight_shape = [
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
@@ -551,6 +550,8 @@ class RowParallelLinear(LinearBase):
|
||||
if fd_config.quant_config:
|
||||
self.quant_method = fd_config.quant_config.get_quant_method(self)
|
||||
self.quant_method.create_weights(self)
|
||||
|
||||
self.reduce_results = reduce_results
|
||||
self.init_weight()
|
||||
|
||||
def init_weight(self):
|
||||
@@ -570,7 +571,7 @@ class RowParallelLinear(LinearBase):
|
||||
self.linear_bias = None
|
||||
if self.with_bias:
|
||||
self.linear_bias = self.create_parameter(
|
||||
shape=[self.embed_dim],
|
||||
shape=[self.hidden_size],
|
||||
dtype=self._dtype,
|
||||
is_bias=True,
|
||||
)
|
||||
@@ -589,7 +590,7 @@ class RowParallelLinear(LinearBase):
|
||||
else:
|
||||
out = paddle.matmul(x, self.linear_weight)
|
||||
|
||||
if self.nranks > 1:
|
||||
if self.reduce_results and self.nranks > 1:
|
||||
tensor_model_parallel_all_reduce(out)
|
||||
|
||||
return out
|
||||
|
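The new reduce_results flag on RowParallelLinear lets a caller skip the per-layer tensor-parallel all-reduce when an enclosing layer (for example a fused MoE block, which gains the same flag further below) already reduces the combined output once. A hedged usage sketch; every constructor argument other than reduce_results is an assumption inferred from the surrounding code:

# Hypothetical: the expert down-projection defers its all-reduce to the MoE
# combine step, avoiding a second tensor-parallel reduction of the same data.
down_proj = RowParallelLinear(
    fd_config=fd_config,              # assumed existing FDConfig
    input_size=moe_intermediate_size,  # assumed per-expert FFN width
    output_size=hidden_size,
    reduce_results=False,             # skip the all-reduce inside forward()
)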
@@ -14,10 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
|
||||
@@ -28,12 +33,12 @@ class ParallelLMHead(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
prefix="",
|
||||
with_bias=False,
|
||||
):
|
||||
fd_config: FDConfig,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
prefix: str = "",
|
||||
with_bias: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Parallelized LMhead.
|
||||
|
||||
@@ -43,21 +48,22 @@ class ParallelLMHead(nn.Layer):
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
num_embeddings (int): vocabulary size.
|
||||
embedding_dim (int): size of hidden state.
|
||||
prefix (str): full name of the layer in the state dict
|
||||
prefix (str): The name of current layer. Defaults to "".
|
||||
with_bias (bool): whether to have bias. Default: False.
|
||||
"""
|
||||
super(ParallelLMHead, self).__init__()
|
||||
self.linear_weight_key = prefix + ".weight"
|
||||
self.linear_weight_key: str = prefix + ".weight"
|
||||
if with_bias:
|
||||
self.linear_bias_key = prefix + ".bias"
|
||||
self.linear_bias_key: Optional[str] = prefix + ".bias"
|
||||
else:
|
||||
self.linear_bias_key = None
|
||||
self.use_ep = fd_config.parallel_config.use_ep
|
||||
self.linear_bias_key: Optional[str] = None
|
||||
self.use_ep: bool = fd_config.parallel_config.use_ep
|
||||
self.column_cut = True
|
||||
|
||||
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
|
||||
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
|
||||
|
||||
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
|
||||
if self.use_ep:
|
||||
self.weight = self.create_parameter(
|
||||
@@ -92,7 +98,8 @@ class ParallelLMHead(nn.Layer):
|
||||
fuse_matmul_bias=False,  # False keeps the numerical diff smaller
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
@@ -122,7 +129,7 @@ class ParallelLMHead(nn.Layer):
|
||||
paddle.get_default_dtype())
|
||||
self.out_linear.bias.set_value(bias)
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
|
@@ -22,14 +22,35 @@ from paddleformers.utils.log import logger
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from ..utils import get_tensor, create_and_set_parameter
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from ..utils import create_and_set_parameter, get_tensor
|
||||
from .fused_moe_backend_base import MoEMethodBase
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
|
||||
moe_expert_reduce, noaux_tc)
|
||||
|
||||
|
||||
# used for deepseek_v3
|
||||
def get_moe_scores(gating_output: paddle.Tensor, n_group, topk_group, top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||
scores = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores
|
||||
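The exact contract of the noaux_tc custom op is not spelled out here; the sketch below is only a rough pure-Paddle reference of DeepSeek-V3-style group-limited routing under assumed semantics (the helper name and the place where routed_scaling_factor is applied are assumptions). The final per-token top-k expert selection is then left to moe_expert_dispatch with topk_only_mode=True, as in the branch below.

import paddle


def group_limited_scores_reference(scores, scores_with_bias, n_group,
                                   topk_group, routed_scaling_factor):
    # Assumed semantics: rank expert groups per token by their best biased
    # score, keep only the top `topk_group` groups, and zero the scores of
    # every expert in the discarded groups.
    num_tokens, num_experts = scores_with_bias.shape
    group_size = num_experts // n_group
    grouped = scores_with_bias.reshape([num_tokens, n_group, group_size])
    group_scores = grouped.max(axis=-1)                  # [num_tokens, n_group]
    group_idx = paddle.topk(group_scores, k=topk_group, axis=-1)[1]
    group_mask = paddle.zeros([num_tokens, n_group], dtype=scores.dtype)
    group_mask = paddle.put_along_axis(
        group_mask, group_idx,
        paddle.ones(group_idx.shape, dtype=scores.dtype), axis=-1)
    expert_mask = paddle.repeat_interleave(group_mask, group_size, axis=-1)
    return scores_with_bias * expert_mask * routed_scaling_factor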
|
||||
|
||||
class CutlassMoEMethod(MoEMethodBase):
|
||||
"""
|
||||
@@ -199,23 +220,47 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
topk_only_mode=False,
|
||||
)
|
||||
if layer.topk_method == "noaux_tc":
|
||||
gate_out = get_moe_scores(gate_out, layer.n_group,
|
||||
layer.topk_group, layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias)
|
||||
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
None, # Use layer.gate_correction_bias in get_moe_scores.
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
topk_only_mode=True,
|
||||
)
|
||||
else:
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
topk_only_mode=False,
|
||||
)
|
||||
|
||||
if self.moe_quant_type != "w4a8":
|
||||
# only w4a8 needs expert_idx_per_token
|
||||
@@ -234,11 +279,11 @@ class CutlassMoEMethod(MoEMethodBase):
|
||||
permute_indices_per_token,
|
||||
topk_idx,
|
||||
None,
|
||||
norm_topk_prob=True,
|
||||
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
|
||||
if layer.tp_size > 1:
|
||||
if layer.reduce_results and layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(fused_moe_out)
|
||||
|
||||
return fused_moe_out
|
||||
|
@@ -195,8 +195,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
hidden_size = layer.hidden_size
|
||||
num_experts = layer.num_experts
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
|
@@ -17,6 +17,7 @@
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
|
||||
@@ -25,17 +26,24 @@ from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Use Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_method=None):
|
||||
def __init__(self, quant_config=None):
|
||||
"""
|
||||
Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_method = quant_method
|
||||
self.quant_config = quant_config
|
||||
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
|
||||
self.added_scale_attrs = [
|
||||
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
|
||||
@@ -52,7 +60,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(ffn1_weights) == layer.num_local_experts
|
||||
assert len(ffn2_weights) == layer.num_local_experts
|
||||
assert layer.quant_method.quant_config.name() == "wint8"
|
||||
|
||||
algo = layer.quant_method.quant_config.name()
|
||||
|
||||
assert algo == "wint8"
|
||||
|
||||
assert ffn1_weights[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
@@ -63,9 +75,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
|
||||
|
||||
if self.quant_config.name() == "wint8":
|
||||
if algo == "wint8":
|
||||
max_bound = 127
|
||||
elif self.quant_config.name() == "wint4":
|
||||
elif algo == "wint4":
|
||||
max_bound = 7
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
|
||||
@@ -111,15 +123,13 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
|
||||
|
||||
topk_weights, topk_ids = paddle.topk(scores,
|
||||
k=top_k,
|
||||
axis=-1,
|
||||
sorted=False)
|
||||
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
intermediate_cache1 = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
@@ -139,14 +149,12 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
|
||||
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
|
||||
max_num_tokens_padded = sorted_token_ids.shape[0]
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
max_possible_num_post_padded = sorted_token_ids.shape[0]
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x,
|
||||
@@ -158,10 +166,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
moe_intermediate_size * 2,
|
||||
hidden_size,
|
||||
max_num_tokens_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=moe_intermediate_size * 2,
|
||||
K=hidden_size,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
@@ -193,8 +201,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
|
||||
intermediate_cache1)
|
||||
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
fused_moe_kernel_paddle[grid](
|
||||
intermediate_cache2,
|
||||
layer.moe_ffn2_weight,
|
||||
@@ -205,10 +214,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
hidden_size,
|
||||
moe_intermediate_size,
|
||||
max_num_tokens_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=hidden_size,
|
||||
K=moe_intermediate_size,
|
||||
stride_am=intermediate_cache2.strides[0],
|
||||
stride_ak=intermediate_cache2.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
@@ -324,7 +333,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
|
||||
|
||||
topk_weights, topk_ids = paddle.topk(scores,
|
||||
@@ -352,13 +360,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
|
||||
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
|
||||
max_num_tokens_padded = sorted_token_ids.shape[0]
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
max_possible_num_post_padded = sorted_token_ids.shape[0]
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
adamard_matrix = create_hadamard_matrix_map[hidden_size]
|
||||
x = paddle.matmul(x.cast("float32"), adamard_matrix)
|
||||
@@ -371,8 +379,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
permute_x = permute_x / quant_activation_scale
|
||||
permute_x = permute_x.astype("float8_e4m3fn")
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
permute_x,
|
||||
layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
|
||||
@@ -383,10 +389,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
moe_intermediate_size * 2,
|
||||
hidden_size,
|
||||
max_num_tokens_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=moe_intermediate_size * 2,
|
||||
K=hidden_size,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
@@ -426,8 +432,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
intermediate_cache2 = intermediate_cache2 / quant_activation_scale
|
||||
intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn")
|
||||
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
intermediate_cache2,
|
||||
@@ -439,10 +446,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
hidden_size,
|
||||
moe_intermediate_size,
|
||||
max_num_tokens_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=hidden_size,
|
||||
K=moe_intermediate_size,
|
||||
stride_am=intermediate_cache2.strides[0],
|
||||
stride_ak=intermediate_cache2.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
|
@@ -224,6 +224,7 @@ class TritonWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
)
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
|
||||
|
||||
fused_moe_out = moe_expert_reduce(
|
||||
ffn_out,
|
||||
topk_weights,
|
||||
|
@@ -30,10 +30,15 @@ class FusedMoE(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
reduce_results: bool = True,
|
||||
moe_intermediate_size: int = -1,
|
||||
num_experts: int = -1,
|
||||
expert_id_offset: int = 0,
|
||||
top_k: int = -1,
|
||||
topk_method: str = "",
|
||||
topk_group: int = -1,
|
||||
n_group: int = -1,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
layer_idx: int = -1,
|
||||
moe_tag: str = "",
|
||||
weight_key_map: dict = {},
|
||||
@@ -49,6 +54,7 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
self.fd_config = fd_config
|
||||
self.layer_idx = layer_idx
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_degree
|
||||
@@ -60,28 +66,33 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.moe_config = fd_config.moe_config
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.num_local_experts = self.num_experts // self.ep_size
|
||||
|
||||
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
|
||||
|
||||
self.top_k = top_k
|
||||
self.hidden_size = self.hidden_size
|
||||
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
|
||||
self.weight_key_map = weight_key_map
|
||||
|
||||
self.use_method = envs.FD_MOE_BACKEND.lower()
|
||||
self.gate_correction_bias = None
|
||||
self.moe_tag = moe_tag
|
||||
|
||||
if self.ep_size > 1:
|
||||
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
|
||||
|
||||
self.expert_id_offset = expert_id_offset
|
||||
|
||||
if fd_config.quant_config:
|
||||
self.quant_method = fd_config.quant_config.get_quant_method(self)
|
||||
# used for deepseek_v3
|
||||
self.topk_method = topk_method
|
||||
self.topk_group = topk_group
|
||||
self.n_group = n_group
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
|
||||
moe_quant_config = fd_config.quant_config
|
||||
self.moe_quant_type = None
|
||||
if moe_quant_config:
|
||||
self.quant_method = moe_quant_config.get_quant_method(self)
|
||||
self.moe_quant_type = moe_quant_config.name()
|
||||
else:
|
||||
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
|
||||
from .fused_moe_cutlass_backend import CutlassMoEMethod
|
||||
@@ -90,12 +101,78 @@ class FusedMoE(nn.Layer):
|
||||
if self.ep_size > 1:
|
||||
self.quant_method.init_ep(self)
|
||||
|
||||
if fd_config.load_config.dynamic_load_weight:
|
||||
# It's for RL to build model
|
||||
self.init_moe_weights()
|
||||
|
||||
logger.info(
|
||||
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \
|
||||
{top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \
|
||||
, ep_size={self.ep_size}, \
|
||||
tp_size={self.tp_size}.")
|
||||
|
||||
def init_moe_weights(self):
|
||||
"""
|
||||
Initialize the weight shapes and parameters for the MoE layer.
|
||||
Combines weight shape initialization and parameter creation into a single function.
|
||||
"""
|
||||
# Initialize weight shapes
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self.weight_dtype = self._dtype
|
||||
gate_weight_shape = [self.hidden_size, self.num_experts]
|
||||
gate_correction_bias_shape = [1, self.num_experts]
|
||||
|
||||
self.gate_weight = self.create_parameter(
|
||||
shape=gate_weight_shape,
|
||||
dtype="float32",
|
||||
)
|
||||
if self.moe_config.moe_use_aux_free:
|
||||
self.gate_correction_bias = self.create_parameter(
|
||||
shape=gate_correction_bias_shape,
|
||||
dtype="float32",
|
||||
)
|
||||
ffn1_output_dim = self.moe_intermediate_size * 2
|
||||
if self.moe_quant_type in ["fp8", "wint8"]:
|
||||
ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size]
|
||||
ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
|
||||
else:
|
||||
ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim]
|
||||
ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
|
||||
|
||||
# Create parameters
|
||||
if self.moe_quant_type == "fp8":
|
||||
#(TODO:gaoziyuan)
|
||||
pass
|
||||
elif self.moe_quant_type == "wint8":
|
||||
self.weight_dtype = "int8"
|
||||
self.init_weight_only_scale()
|
||||
|
||||
# FFN1 parameters
|
||||
self.moe_ffn1_weight = self.create_parameter(
|
||||
shape=ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
# FFN2 parameters
|
||||
self.moe_ffn2_weight = self.create_parameter(
|
||||
shape=ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
def init_weight_only_scale(self):
|
||||
"""
|
||||
Initialize the weight scale.
|
||||
"""
|
||||
self.moe_ffn1_weight_scale = self.create_parameter(
|
||||
shape=[self.num_local_experts, self.moe_intermediate_size * 2],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
self.moe_ffn2_weight_scale = self.create_parameter(
|
||||
shape=[self.num_local_experts, self.hidden_size],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
def load_experts_weight(self, state_dict: dict,
|
||||
ffn1_expert_weight_key: str,
|
||||
ffn2_expert_weight_key: str):
|
||||
|
@@ -16,9 +16,10 @@
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import paddle_use_triton_v2
|
||||
|
||||
|
||||
@triton.jit
|
||||
@paddle_use_triton_v2()
|
||||
def fused_moe_kernel_paddle(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
@@ -31,22 +32,22 @@ def fused_moe_kernel_paddle(
|
||||
num_tokens_post_padded_ptr,
|
||||
|
||||
# Matrix dimensions
|
||||
N,
|
||||
K,
|
||||
num_tokens_post_padded,
|
||||
max_possible_num_post_padded,
|
||||
num_valid_tokens,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
stride_am: tl.constexpr,
|
||||
stride_ak: tl.constexpr,
|
||||
stride_be: tl.constexpr,
|
||||
stride_bk: tl.constexpr,
|
||||
stride_bn: tl.constexpr,
|
||||
stride_cm: tl.constexpr,
|
||||
stride_cn: tl.constexpr,
|
||||
stride_asm: tl.constexpr,
|
||||
stride_ask: tl.constexpr,
|
||||
stride_bse: tl.constexpr,
|
||||
stride_bsk: tl.constexpr,
|
||||
stride_bsn: tl.constexpr,
|
||||
# Block size for block-wise fp8 quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
@@ -87,7 +88,7 @@ def fused_moe_kernel_paddle(
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M)
|
||||
num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
|
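The signature change above promotes matrix sizes and strides from runtime scalars to tl.constexpr parameters, so Triton can specialize and constant-fold them per launch configuration. A standalone sketch of the same idea (plain Triton, independent of the FastDeploy kernels and of the paddle_use_triton_v2 wrapper):

import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, alpha,
                 BLOCK_SIZE: tl.constexpr):
    # BLOCK_SIZE is baked into the compiled kernel, enabling unrolling and
    # compile-time address arithmetic; n_elements and alpha stay runtime values.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * alpha, mask=mask)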
133
fastdeploy/model_executor/layers/mtp_linear.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
|
||||
class ParallelEHProjection(nn.Layer):
|
||||
"""
|
||||
"Parallelized Embedding Hidden States Projection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
prefix="",
|
||||
with_bias=False,
|
||||
):
|
||||
"""
|
||||
Parallelized Embedding Hidden States Projection.
|
||||
|
||||
Args:
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
num_embeddings (int): vocabulary size.
|
||||
embedding_dim (int): size of hidden state.
|
||||
prefix (str): full name of the layer in the state dict
|
||||
"""
|
||||
super(ParallelEHProjection, self).__init__()
|
||||
self.linear_weight_key = prefix + ".weight"
|
||||
if with_bias:
|
||||
self.linear_bias_key = prefix + ".bias"
|
||||
else:
|
||||
self.linear_bias_key = None
|
||||
self.use_ep = fd_config.parallel_config.use_ep
|
||||
self.column_cut = True
|
||||
|
||||
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
|
||||
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
|
||||
|
||||
if self.use_ep:
|
||||
self.weight = self.create_parameter(
|
||||
shape=[embedding_dim, num_embeddings],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
is_bias=False,
|
||||
)
|
||||
else:
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.out_linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().
|
||||
get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True
|
||||
if self.linear_bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False,  # False keeps the numerical diff smaller
|
||||
)
|
||||
else:
|
||||
self.out_linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().
|
||||
get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True
|
||||
if self.linear_bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False,  # False keeps the numerical diff smaller
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
Args:
|
||||
state_dict (dict): A dictionary containing the checkpoint weights and biases.
|
||||
"""
|
||||
|
||||
if self.use_ep:
|
||||
self.weight.set_value(
|
||||
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
|
||||
paddle.get_default_dtype()))
|
||||
else:
|
||||
weight_tensor = get_tensor(
|
||||
state_dict.pop(self.linear_weight_key)).astype(
|
||||
paddle.get_default_dtype())
|
||||
if self.out_linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
self.out_linear.weight.set_value(weight_tensor)
|
||||
|
||||
if self.linear_bias_key is not None:
|
||||
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
|
||||
paddle.get_default_dtype())
|
||||
self.out_linear.bias.set_value(bias)
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
Args:
|
||||
input (Tensor): The input tensor to the layer.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor after processing through the layer.
|
||||
"""
|
||||
logits = input
|
||||
if self.use_ep:
|
||||
logits = paddle.matmul(logits, self.weight)
|
||||
else:
|
||||
logits = self.out_linear(logits)
|
||||
return logits
|
@@ -14,10 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
|
||||
@@ -28,16 +33,16 @@ class RMSNorm(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
prefix="",
|
||||
linear_bias=None,
|
||||
quant_scale=None,
|
||||
begin_norm_axis=1,
|
||||
):
|
||||
fd_config: FDConfig,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-5,
|
||||
prefix: str = "",
|
||||
linear_bias: paddle.Tensor = None,
|
||||
quant_scale: float = None,
|
||||
begin_norm_axis: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the normalization layer.
|
||||
Initializes the RMSNormalization layer.
|
||||
|
||||
Args:
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
@@ -45,33 +50,33 @@ class RMSNorm(nn.Layer):
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
hidden_size (int) : size of hidden state.
|
||||
eps (float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
|
||||
weight_key (str): Key name of weight in the pdparams state dict. Defaults to None, means no weight.
|
||||
bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
|
||||
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
|
||||
prefix (str, optional): The name of the current layer. Defaults to "".
|
||||
linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
|
||||
quant_scale (float, optional): Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
|
||||
begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the specified norm_type is not supported.
|
||||
"""
|
||||
super().__init__()
|
||||
self.fd_config = fd_config
|
||||
self.prefix = prefix
|
||||
self.hidden_size = hidden_size
|
||||
self.prefix: str = prefix
|
||||
self.hidden_size: int = hidden_size
|
||||
if len(prefix) == 0:
|
||||
self.weight_key = None
|
||||
self.weight_key: Optional[str] = None
|
||||
else:
|
||||
self.weight_key = f"{prefix}.weight"
|
||||
self.with_weight = self.weight_key is not None
|
||||
self.eps = eps
|
||||
self.norm_func = fused_rms_norm
|
||||
self.linear_bias = linear_bias
|
||||
self.quant_scale = quant_scale
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype = self._dtype
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
self.weight_key: Optional[str] = f"{prefix}.weight"
|
||||
self.with_weight: bool = self.weight_key is not None
|
||||
self.eps: float = eps
|
||||
self.norm_func: Callable = fused_rms_norm
|
||||
self.linear_bias: Optional[paddle.Tensor] = linear_bias
|
||||
self.quant_scale: Optional[float] = quant_scale
|
||||
self._dtype: str = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype: str = self._dtype
|
||||
self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
self.begin_norm_axis: int = begin_norm_axis
|
||||
|
||||
self.init_weight()
|
||||
|
||||
@@ -88,7 +93,8 @@ class RMSNorm(nn.Layer):
|
||||
dtype=self._norm_weight_dtype,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
@@ -102,7 +108,10 @@ class RMSNorm(nn.Layer):
|
||||
self._norm_weight_dtype)
|
||||
self.ln_weight.set_value(weight_tensor)
|
||||
|
||||
def forward(self, x, residual_input=None):
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
@@ -140,18 +149,18 @@ class RMSNorm(nn.Layer):
|
||||
|
||||
class LayerNorm(nn.Layer):
|
||||
"""
|
||||
Normalization layer.
|
||||
Initializes the LayerNormalization layer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
fd_config: FDConfig,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-5,
|
||||
prefix="",
|
||||
linear_bias=None,
|
||||
quant_scale=None,
|
||||
with_bias=False,
|
||||
linear_bias: paddle.Tensor = None,
|
||||
quant_scale: float = None,
|
||||
with_bias: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the normalization layer.
|
||||
@@ -160,35 +169,37 @@ class LayerNorm(nn.Layer):
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
prefix (str): Unique name of the layer, used for naming internal attributes,
|
||||
you can give it any name you like.
|
||||
hidden_size (int) : size of hidden state.
|
||||
eps (float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
|
||||
prefix (str): Unique name of the layer, used for naming internal attributes,
|
||||
you can give it any name you like.
|
||||
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
|
||||
quant_scale (float, optional): Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
|
||||
with_bias (bool): Whether to include bias or not. Defaults to False.
|
||||
Raises:
|
||||
NotImplementedError: If the specified norm_type is not supported.
|
||||
"""
|
||||
super().__init__()
|
||||
self.fd_config = fd_config
|
||||
self.prefix = prefix
|
||||
self.hidden_size = hidden_size
|
||||
self.prefix: str = prefix
|
||||
self.hidden_size: int = hidden_size
|
||||
if len(prefix) == 0:
|
||||
self.weight_key = None
|
||||
self.weight_key: Optional[str] = None
|
||||
else:
|
||||
self.weight_key = f"{prefix}.weight"
|
||||
self.with_weight = self.weight_key is not None
|
||||
self.bias_key = f"{prefix}.bias"
|
||||
self.with_bias = with_bias
|
||||
self.eps = eps
|
||||
self.weight_key: Optional[str] = f"{prefix}.weight"
|
||||
self.with_weight: bool = self.weight_key is not None
|
||||
self.bias_key: str = f"{prefix}.bias"
|
||||
self.with_bias: bool = with_bias
|
||||
self.eps: float = eps
|
||||
self.quant_scale: float = quant_scale
|
||||
self.norm_func: Callable = fused_layer_norm
|
||||
self.linear_bias: Optional[paddle.Tensor] = linear_bias
|
||||
self._dtype: str = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype: str = "float32"
|
||||
|
||||
self.norm_func = fused_layer_norm
|
||||
self.linear_bias = linear_bias
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype = "float32"
|
||||
|
||||
self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
|
||||
self.init_weight()
|
||||
|
||||
@@ -212,7 +223,8 @@ class LayerNorm(nn.Layer):
|
||||
dtype=self._norm_weight_dtype,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
@@ -233,7 +245,10 @@ class LayerNorm(nn.Layer):
|
||||
self._norm_weight_dtype)
|
||||
self.ln_bias.set_value(bias_tensor)
|
||||
|
||||
def forward(self, x, residual_input=None):
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
@@ -259,7 +274,7 @@ class LayerNorm(nn.Layer):
|
||||
begin_norm_axis=1,
|
||||
bias=self.linear_bias,
|
||||
residual=residual_input,
|
||||
quant_scale=-1,
|
||||
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
|
||||
quant_round_type=self.quant_round_type,
|
||||
quant_max_bound=self.quant_max_bound,
|
||||
quant_min_bound=self.quant_min_bound,
|
||||
|
@@ -15,8 +15,9 @@
"""
from typing import Optional

-from ..attention import Attention
-from ..moe import FusedMoE
+from fastdeploy.model_executor.layers.attention.attention import Attention
+from fastdeploy.model_executor.layers.moe.moe import FusedMoE

from . import get_quantization_config
from .quant_base import QuantConfigBase, QuantMethodBase

@@ -132,18 +132,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
        self.quant_config = quant_config

    def create_weights(self, layer):

+        # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
+        linear_weight_scale_shape = [layer.linear_weight_shape[1]]
+
        layer.linear_weight_shape.reverse()
        if self.quant_config.name() == "wint4":
            layer.linear_weight_shape[0] //= 2
        layer.weight_dtype = "int8"
-        linear_weight_scale_shape = [layer.embed_dim]
-        if hasattr(layer, "linear_weight_shape"):
-            if isinstance(layer.linear_weight_shape, list):
-                layer_weight_shape = layer.linear_weight_shape
-                linear_weight_scale_shape = layer_weight_shape[:1]
-                if self.quant_config.name() == "wint4":
-                    linear_weight_scale_shape[0] *= 2

        layer.linear_weight_scale = layer.create_parameter(
            shape=linear_weight_scale_shape,
            dtype=layer._dtype,
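As the comment in create_weights notes, per-channel weight-only quantization keeps one scale per output channel, so the scale vector has length equal to the weight's output dimension, and wint4 packs two 4-bit values into each stored int8 byte. A schematic NumPy sketch of per-channel int8 scales (illustrative assumptions, not the actual weight_quantize op):

```python
import numpy as np

in_features, out_features = 4096, 12288          # assumed example shapes
w = np.random.randn(in_features, out_features).astype(np.float32)

# One scale per output channel: shape [out_features], matching
# linear_weight_scale_shape = [layer.linear_weight_shape[1]] above.
scale = np.abs(w).max(axis=0) / 127.0
w_int8 = np.clip(np.rint(w / scale), -127, 127).astype(np.int8)

assert scale.shape == (out_features,)
# For "wint4", two 4-bit values share one int8 byte, which is why the packed
# weight's leading dimension is halved (linear_weight_shape[0] //= 2).
```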
@@ -195,6 +191,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
            weight_scale.astype(paddle.get_default_dtype()))

    def process_loaded_weights(self, layer, weight) -> None:

        quanted_weight_tensor, weight_scale_tensor = weight_quantize(
            weight,
            algo=self.quant_config.algo,
@@ -14,13 +14,18 @@
# limitations under the License.
"""

-from typing import Optional
+import math
+from typing import Optional, Tuple

import paddle
import paddle.nn as nn

from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding

from .utils import CpuGuard

@@ -99,20 +104,164 @@ class QwenRotaryEmbedding:
        return rot_emb


def yarn_get_mscale(scale=1, mscale=1):
    """
    """
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_find_correction_dim(num_rotations,
                             dim,
                             base=10000,
                             max_position_embeddings=2048):
    """
    """
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


def yarn_find_correction_range(low_rot,
                               high_rot,
                               dim,
                               base=10000,
                               max_position_embeddings=2048):
    """
    """
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def yarn_linear_ramp_mask(min, max, dim):
    """
    """
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max -
                                                                      min)
    ramp_func = paddle.clip(linear_func, 0, 1)
    return ramp_func

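Plugging in representative numbers makes the role of these helpers concrete; the values below are illustrative assumptions (a 64-dim rotary head extended 4x), and the expected results are rounded:

```python
import math

# mscale grows logarithmically with the context-extension factor:
# 0.1 * 1.0 * ln(4.0) + 1.0 ~= 1.139
print(yarn_get_mscale(scale=4.0, mscale=1.0))

# For rotary_dim=64, base=10000, original context 2048, the default
# beta_fast=32 / beta_slow=1 cutoffs give a correction range of roughly (8, 21):
low, high = yarn_find_correction_range(32, 1, 64, base=10000,
                                       max_position_embeddings=2048)
print(low, high)

# The ramp mask produced from (low, high) is 0 below `low`, 1 above `high`,
# and linearly blends interpolated vs. extrapolated frequencies in between.
```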
class DeepseekScalingRotaryEmbedding(nn.Layer):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn

    Args:
        rotary_dim(int): Dimension of rotary embeddings (head dimension)
        max_position_embeddings(int): Original training context length
        base(float): Base value used to compute the inverse frequencies.
        scaling_factor(float): Context extension scaling ratio (target_len / original_len)
        extrapolation_factor(float): Weight for extrapolated frequencies (default=1)
        attn_factor(float): Attention magnitude scaling factor (default=1)
        beta_fast(int): High-frequency correction cutoff (default=32)
        beta_slow(int): Low-frequency correction cutoff (default=1)
        mscale(float): Primary magnitude scaling factor (default=1)
        mscale_all_dim(float): Alternate magnitude scaling factor (default=0)
    """

    def __init__(
        self,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        super().__init__()
        self._dtype = paddle.get_default_dtype()

        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)

        cache = self._compute_cos_sin_cache()

        self.cos_sin_cache: paddle.Tensor
        self.register_buffer("cos_sin_cache", cache, persistable=True)

    def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
        pos_freqs = self.base**(
            paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) /
            self.rotary_dim)

        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                               self.rotary_dim, self.base,
                                               self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> paddle.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = paddle.arange(self.max_position_embeddings * self.scaling_factor,
                          dtype=paddle.float32)
        freqs = paddle.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
        cache = paddle.concat((cos, sin), axis=-1)
        return cache.cast(self._dtype)

    def forward(
        self,
        position_ids: paddle.Tensor,
        query: paddle.Tensor,
        key: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """
        """
        # In-place operations that update the query and key tensors.
        fused_rotary_position_encoding(query, key, position_ids,
                                       self.cos_sin_cache, self.rotary_dim,
                                       False)

        return query, key

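A minimal sketch of how this class might be exercised, assuming the module above is importable (forward additionally needs the CUDA fused op); the cache shape follows directly from _compute_cos_sin_cache, which stores cos and sin halves concatenated for every extended position:

```python
import paddle

# Assumed example configuration: 64-dim rotary head, 2048 original context,
# extended 4x with YaRN.
rope = DeepseekScalingRotaryEmbedding(
    rotary_dim=64,
    max_position_embeddings=2048,
    base=10000,
    scaling_factor=4.0,
)

# cos/sin are precomputed for every extended position:
# [2048 * 4, 64] -> the first 32 columns are cos, the last 32 are sin.
print(rope.cos_sin_cache.shape)   # expected: [8192, 64]
```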
def get_rope_impl(
    rotary_dim: int,
    base: 10000.0,
-    position_ids,
+    position_ids: paddle.Tensor,
    model_config: Optional[ModelConfig] = None,
    partial_rotary_factor=1,
-):
+) -> paddle.Tensor:
    """
    The real implementation of get_rope
    """

    architecture = model_config.architectures[0]
-    if model_config is not None and model_config is None or architecture.startswith(
-            "Qwen"):
+    if model_config is None or architecture.startswith("Qwen"):
        rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base,
                                               partial_rotary_factor)
        rotary_emb = rotary_emb_layer(position_ids)
@@ -126,10 +275,10 @@ def get_rope_impl(
def get_rope_xpu(
    rotary_dim: int,
    base: 10000.0,
-    position_ids,
-    model_config: ModelConfig,
+    position_ids: paddle.Tensor,
+    model_config: Optional[ModelConfig] = None,
    partial_rotary_factor=1,
-):
+) -> paddle.Tensor:
    """
    In XPU, cos and sin compute must be done on cpu
    """
@@ -143,12 +292,27 @@ def get_rope_xpu(
def get_rope(
    rotary_dim: int,
    base: 10000.0,
-    position_ids,
-    model_config: ModelConfig,
-    partial_rotary_factor=1,
-):
+    position_ids: paddle.Tensor,
+    model_config: Optional[ModelConfig] = None,
+    partial_rotary_factor: int = 1,
+) -> paddle.Tensor:
    """
-    The warpper of get_rope
+    Pre-calculate rotary position embedding for position_ids.

    Args:
        rotary_dim (int):
            Dimension of rotary embeddings (head dimension)
        base (float, optional):
            Base value used to compute the inverse frequencies.
            Default: 10000.0.
        position_ids (paddle.Tensor):
            Tensor containing position indices of input tokens.
        model_config (Optional[ModelConfig]):
            Model configuration object containing architecture information.
            If provided, determines RoPE implementation based on model architecture.
        partial_rotary_factor (int, optional):
            Factor controlling partial rotary application.
            Default: 1 (apply to all dimensions).
    """
    if current_platform.is_xpu():
        return get_rope_xpu(rotary_dim, base, position_ids, model_config,
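A hedged usage sketch of the wrapper, with assumed shapes; `cfg` stands in for a hypothetical ModelConfig whose architecture string would select the concrete RoPE implementation:

```python
import paddle

# Assumed: one sequence of 1024 tokens; position ids shaped [batch, seq_len].
position_ids = paddle.arange(1024, dtype="int64").unsqueeze(0)

# `cfg` is a hypothetical ModelConfig whose architectures[0] starts with "Qwen",
# so get_rope_impl would route to QwenRotaryEmbedding.
rot_emb = get_rope(rotary_dim=128,
                   base=10000.0,
                   position_ids=position_ids,
                   model_config=cfg)
```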
@@ -255,7 +419,24 @@ def get_rope_3d(
    paritial_rotary_factor: 1,
    max_position: 131072,
    freq_allocation: 2,
-):
+) -> paddle.Tensor:
    """
    Pre-calculate rotary position embedding for position_ids.

    Args:
        rotary_dim (int):
            Dimension of rotary embeddings (head dimension)
        base (float, optional):
            Base value used to compute the inverse frequencies.
            Default: 10000.0.
        position_ids (paddle.Tensor):
            Tensor containing position indices of input tokens.
        partial_rotary_factor (int, optional):
            Factor controlling partial rotary application.
            Default: 1 (apply to all dimensions).
        max_position: Maximum position index to precompute.
        freq_allocation: Number of rotary dimensions allocated to temporal axis
    """
    rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base,
                                                  paritial_rotary_factor,
                                                  max_position,
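Only the tail of get_rope_3d's signature appears in this hunk, so the sketch below fills in the leading parameters by analogy with get_rope and should be read as an assumption rather than the actual signature:

```python
# Hypothetical call, mirroring the documented parameters above.
rot_emb_3d = get_rope_3d(
    rotary_dim=128,
    base=10000.0,
    position_ids=position_ids_3d,      # assumed 3-D (temporal/height/width) ids
    paritial_rotary_factor=1,          # note: parameter name as spelled in the source
    max_position=131072,
    freq_allocation=2,                 # rotary dims reserved for the temporal axis
)
```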
@@ -42,3 +42,4 @@ class SamplingMetadata:

    top_p: paddle.Tensor
    top_k: Optional[paddle.Tensor] = None
    max_num_logprobs: Optional[int] = None
@@ -16,10 +16,10 @@

from .apply_penalty_multi_scores import (
    apply_penalty_multi_scores, apply_speculative_penalty_multi_scores)
-from .top_p_sampling import top_p_sampling
+from .top_k_top_p_sampling import top_k_top_p_sampling

__all__ = [
    "apply_penalty_multi_scores",
    "apply_speculative_penalty_multi_scores",
-    "top_p_sampling",
+    "top_k_top_p_sampling",
]
@@ -0,0 +1,152 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Literal, Optional

import paddle

from fastdeploy import envs

def top_k_top_p_sampling(
    x: paddle.Tensor,
    top_p: paddle.Tensor,
    top_k: Optional[paddle.Tensor] = None,
    threshold: Optional[paddle.Tensor] = None,
    topp_seed: Optional[paddle.Tensor] = None,
    seed: int = -1,
    k: int = 0,
    mode: Literal['truncated', 'non-truncated'] = "truncated",
    order: Literal['top_k_first', 'joint'] = "top_k_first",
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    x(Tensor): An input 2-D Tensor with type float32, float16 and bfloat16.
    top_p(Tensor): A 1-D Tensor with type float32, float16 and bfloat16,
        used to specify the top_p corresponding to each query.
    top_k(Tensor|None, optional): A 1-D Tensor with type int64,
        used to specify the top_k corresponding to each query.
        Only used when FD_SAMPLING_CLASS is `rejection`.
    threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
        used to avoid sampling low score tokens.
    topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
        used to specify the random seed for each query.
    seed(int, optional): the random seed. Default is -1,
    k(int): the number of top_k scores/ids to be returned. Default is 0.
        Only used when FD_SAMPLING_CLASS is `air`.
    mode(str): The mode to choose sampling strategy. If the mode is `truncated`, sampling will truncate the probability at top_p_value.
        If the mode is `non-truncated`, it will not be truncated. Default is `truncated`.
        Only used when FD_SAMPLING_CLASS is `air` or `base`.
    order(str): The order of applying top-k and top-p sampling, should be either `top_k_first` or `joint`.
        If `top_k_first`, we first apply top-k filter, then apply top-p sampling on the top-k results.
        If `joint`, we apply top-k and top-p filter simultaneously in each round. Default is `top_k_first`.
        Only used when FD_SAMPLING_CLASS is `rejection`.

    """
    top_p_class = envs.FD_SAMPLING_CLASS.lower()
    if top_p_class == "air":
        _, ids = air_top_p_sampling(x,
                                    top_p,
                                    threshold,
                                    topp_seed,
                                    seed=seed,
                                    k=k,
                                    mode=mode)
    # rejection
    elif top_p_class == "rejection":
        ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
        _ = None
    # base non-truncated
    elif top_p_class == "base_non_truncated":
        _, ids = paddle.tensor.top_p_sampling(x,
                                              top_p,
                                              threshold=threshold,
                                              topp_seed=topp_seed,
                                              seed=seed,
                                              k=k,
                                              mode="non-truncated")
    # base truncated
    else:
        _, ids = paddle.tensor.top_p_sampling(x,
                                              top_p,
                                              threshold=threshold,
                                              topp_seed=topp_seed,
                                              seed=seed,
                                              k=k,
                                              mode="truncated")
    return _, ids

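A minimal sketch of calling the dispatcher above; the batch size, vocabulary size, and per-query settings are assumptions, and the backend actually used is selected by the FD_SAMPLING_CLASS environment variable as the code shows:

```python
import paddle

batch_size, vocab_size = 2, 32000                     # assumed sizes
probs = paddle.nn.functional.softmax(
    paddle.randn([batch_size, vocab_size]), axis=-1)

top_p = paddle.full([batch_size], 0.8, dtype=probs.dtype)   # per-query top-p
top_k = paddle.to_tensor([20, 0], dtype="int64")            # 0 disables top-k for that query

scores, next_token_ids = top_k_top_p_sampling(probs, top_p, top_k)
# On the rejection path the first return value is None; `next_token_ids`
# holds one sampled token id per query.
```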
def air_top_p_sampling(
    x: paddle.Tensor,
    top_p: paddle.Tensor,
    threshold: Optional[paddle.Tensor] = None,
    topp_seed: Optional[paddle.Tensor] = None,
    seed: int = -1,
    k: int = 0,
    mode: Literal['truncated', 'non-truncated'] = "truncated",
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    air_top_p_sampling
    """
    try:
        from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
        out, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed, k,
                                      mode)
    except ImportError:
        raise RuntimeError("Cannot import air_top_p_sampling op.")
    return out, ids

def rejection_top_p_sampling(
    x: paddle.Tensor,
    top_p: paddle.Tensor,
    top_k: paddle.Tensor,
    seed: int = -1,
    order: Literal['top_k_first', 'joint'] = "top_k_first",
) -> paddle.Tensor:
    """
    rejection_top_p_sampling
    """
    try:
        from fastdeploy.model_executor.ops.gpu import (
            rejection_top_p_sampling, top_k_renorm_probs)

        if paddle.count_nonzero(top_k) == 0:
            ids = rejection_top_p_sampling(
                x,
                top_p,
                None,
                seed,
            )
        else:
            if order == "top_k_first":
                renorm_probs = top_k_renorm_probs(x, top_k)
                ids = rejection_top_p_sampling(
                    renorm_probs,
                    top_p,
                    None,
                    seed,
                )
            else:
                ids = rejection_top_p_sampling(
                    x,
                    top_p,
                    top_k,
                    seed,
                )
    except ImportError:
        raise RuntimeError("Cannot import rejection_top_p_sampling op.")
    return ids