Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-09-30 22:32:30 +08:00

Compare commits: v2.0.0 ... release/2. (375 commits, fa5a07b8fc through b8a8a19689; the author and date columns of the commit table did not survive the mirror)
7  .flake8  Normal file

@@ -0,0 +1,7 @@
[flake8]
ignore = E203, E402, E501, E731, E741, W503, W605, E722, E231, W604, E702, E226, E221, E713, E271
max-line-length = 119

# E402: module level import not at top of file
per-file-ignores =
    __init__.py:F401,F403,E402
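Flake8 discovers this file automatically from the repository root, so the same gate can be run locally; a minimal sketch, assuming flake8 is installed in the active environment:

```bash
# Run from the repo root; flake8 picks up the [flake8] section in .flake8.
python -m pip install flake8
flake8 .
```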
50  .github/workflows/Codestyle-Check.yml  vendored  Normal file

@@ -0,0 +1,50 @@
name: Codestyle-Check

on:
  pull_request:
    branches:
      - develop
      - 'release/*'

jobs:
  pre-commit:
    name: Pre Commit
    if: ${{ github.repository_owner == 'PaddlePaddle' }}
    runs-on: ubuntu-latest
    env:
      PR_ID: ${{ github.event.pull_request.number }}
      BRANCH: ${{ github.event.pull_request.base.ref }}

    steps:
      - name: Cleanup
        run: |
          rm -rf * .[^.]*

      - name: Checkout base repo
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.base.ref }}
          fetch-depth: 1000

      - name: Merge PR to test branch
        run: |
          git fetch origin pull/${PR_ID}/merge
          git checkout -b test FETCH_HEAD

      - name: Setup python3.10
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          cache: 'pip'

      - name: Install dependencies
        run: |
          pip install pre-commit==4.2.0 cpplint==1.6.0 clang-format==13.0.0

      - name: Check pre-commit
        env:
          SKIP_CLANG_TIDY_CHECK: "ON"
        run: |
          set +e
          bash -x tools/codestyle/pre_commit.sh;EXCODE=$?
          exit $EXCODE
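The workflow pins its linters and then delegates to tools/codestyle/pre_commit.sh. A hedged local approximation, reusing the pinned versions from the workflow above:

```bash
pip install pre-commit==4.2.0 cpplint==1.6.0 clang-format==13.0.0
# Either drive the hooks directly...
pre-commit run --all-files
# ...or mimic CI more closely (script path taken from the workflow above):
SKIP_CLANG_TIDY_CHECK=ON bash -x tools/codestyle/pre_commit.sh
```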
173  .github/workflows/_build_linux.yml  vendored  Normal file

@@ -0,0 +1,173 @@
name: FastDeploy Linux GPU Build Task
description: "FastDeploy packages build and upload"

on:
  workflow_call:
    inputs:
      DOCKER_IMAGE:
        description: "Build Images"
        required: true
        type: string
        default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
      FASTDEPLOY_ARCHIVE_URL:
        description: "URL of the compressed FastDeploy code archive."
        required: true
        type: string
      COMPILE_ARCH:
        description: "Build GPU Archs"
        required: true
        type: string
        default: "80,90"
      WITH_NIGHTLY_BUILD:
        description: "Enable nightly build mode (e.g. add date suffix to version)"
        required: false
        type: string
        default: "ON"
      FD_VERSION:
        description: "FastDeploy Package Version"
        required: false
        type: string
        default: ""
      UPLOAD:
        description: "Upload Package"
        required: false
        type: string
        default: "ON"
      CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""
    outputs:
      wheel_path:
        description: "Output path of the generated wheel"
        value: ${{ jobs.fd-build.outputs.wheel_path }}

jobs:
  fd-build:
    runs-on: [self-hosted, GPU-Build]
    outputs:
      wheel_path: ${{ steps.set_output.outputs.wheel_path }}
    steps:
      - name: Code Prepare
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
          IS_PR: ${{ github.event_name == 'pull_request' }}
        run: |
          set -x
          REPO="https://github.com/${{ github.repository }}.git"
          FULL_REPO="${{ github.repository }}"
          REPO_NAME="${FULL_REPO##*/}"
          BASE_BRANCH="${{ github.base_ref }}"

          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
            -e "REPO_NAME=${REPO_NAME}" \
            ${docker_image} /bin/bash -c '
            if [ -d ${REPO_NAME} ]; then
              echo "Directory ${REPO_NAME} exists, removing it..."
              rm -rf ${REPO_NAME}*
            fi
          '

          wget -q ${fd_archive_url}
          tar -xf FastDeploy.tar.gz
          rm -rf FastDeploy.tar.gz
          cd FastDeploy
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          git log -n 3 --oneline
      - name: FastDeploy Build
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          compile_arch: ${{ inputs.COMPILE_ARCH }}
          fd_version: ${{ inputs.FD_VERSION }}
          CACHE_DIR: ${{ inputs.CACHE_DIR }}
        run: |
          set -x
          runner_name="${{ runner.name }}"
          CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
          gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)

          CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
          echo "CACHE_DIR is set to ${CACHE_DIR}"
          if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
            touch "${CACHE_DIR}/gitconfig"
          fi
          PARENT_DIR=$(dirname "$WORKSPACE")
          echo "PARENT_DIR:$PARENT_DIR"
          docker run --rm --net=host \
            --cap-add=SYS_PTRACE --privileged --shm-size=64G \
            -v $(pwd):/workspace -w /workspace \
            -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
            -v "${CACHE_DIR}/.cache:/root/.cache" \
            -v "${CACHE_DIR}/.ccache:/root/.ccache" \
            -v "${CACHE_DIR}/ConfigDir:/root/.config" \
            -e TZ="Asia/Shanghai" \
            -e "COMPILE_ARCH=${compile_arch}" \
            -e "FD_VERSION=${fd_version}" \
            -e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \
            --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
            if [[ -n "${FD_VERSION}" ]]; then
              export FASTDEPLOY_VERSION=${FD_VERSION}
              echo "Custom FastDeploy version: ${FASTDEPLOY_VERSION}"
            fi

            git config --global --add safe.directory /workspace/FastDeploy
            cd FastDeploy
            if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
              GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
              DATE_ONLY=$(echo $GIT_COMMIT_TIME | sed "s/ .*//;s/-//g")
              echo "Git Commit Time: $GIT_COMMIT_TIME"
              echo "Date Only: $DATE_ONLY"
              export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}"
            fi
            python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
            pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

            python -m pip install --upgrade pip
            python -m pip install -r requirements.txt
            python -m pip install wheel
            # Compile RDMA
            export ENABLE_FD_RDMA=1
            bash build.sh 1 python false [${COMPILE_ARCH}]
            ls ./dist/*.whl
          '
      - name: Package Upload
        id: set_output
        env:
          compile_arch: ${{ inputs.COMPILE_ARCH }}
        run: |
          set -x
          if [[ "${{ github.event_name }}" == "pull_request" ]];then
            commit_id=${{ github.event.pull_request.head.sha }}
            pr_num=${{ github.event.pull_request.number }}
            target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
          elif [[ "${{ github.ref_type }}" == "tag" ]]; then
            commit_id=${{ github.sha }}
            tag_name=${{ github.ref_name }}
            target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}/SM${compile_arch//,/_}
          else
            commit_id=${{ github.sha }}
            branch_name=${{ github.ref_name }}
            target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}/SM${compile_arch//,/_}
          fi
          wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
          push_file=$(realpath bos_tools.py)
          python --version
          python -m pip install bce-python-sdk==0.9.29
          cd FastDeploy/dist/
          matches=($(ls fastdeploy*.whl))
          if [ ${#matches[@]} -ne 1 ]; then
            echo "Error: Found ${#matches[@]} matching files, expected exactly 1"
            exit 1
          fi
          fd_wheel_name=${matches[0]}
          echo "Found: $fd_wheel_name"
          tree -L 3
          python ${push_file} fastdeploy*.whl ${target_path}
          target_path_stripped="${target_path#paddle-github-action/}"
          WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
          echo "wheel_path=${WHEEL_PATH}" >> $GITHUB_OUTPUT
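The build job derives its GPU set from the runner-name suffix: the last dash-separated field is split into single digits and re-joined with commas for Docker's `--gpus` flag. A sketch with a hypothetical runner name:

```bash
runner_name="GPU-Build-013"                                  # hypothetical runner name
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')   # last field -> 013
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)            # one digit per char -> 0,1,3
echo "--gpus \"device=${gpu_id}\""                           # container sees GPUs 0, 1 and 3
```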
78  .github/workflows/_clone_linux.yml  vendored  Normal file

@@ -0,0 +1,78 @@
name: FastDeploy Code Clone
description: "FastDeploy clone and upload"

on:
  workflow_call:
    inputs:
      bos_dir:
        type: string
        required: false
        default: 'FastDeploy'
    outputs:
      repo_archive_url:
        description: "Compressed source code archive."
        value: ${{ jobs.code-clone.outputs.repo_archive_url }}

jobs:
  code-clone:
    runs-on:
      group: HK-Clone
    outputs:
      repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
    steps:
      - name: Clone FastDeploy
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request'
            && github.event.pull_request.base.ref
            || github.ref_name }}
          submodules: 'recursive'
          fetch-depth: 1000

      - name: Merge PR (if needed)
        if: ${{ github.event_name == 'pull_request' }}
        run: |
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          echo "Fetching and merging PR..."
          git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
          git merge --no-ff pr/${{ github.event.pull_request.number }}
          echo "PR Branch log "
          git log --oneline -n 5 pr/${{ github.event.pull_request.number }}
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Code Info Show and Upload
        id: set_output
        env:
          AK: paddle
          SK: paddle
        run: |
          git config --unset http.https://github.com/.extraheader
          git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
          git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
          echo "Current HEAD Log:"
          git log --oneline -n 5
          ls
          cd ..
          tar -zcf FastDeploy.tar.gz FastDeploy
          if [[ "${{ github.event_name }}" == "pull_request" ]];then
            commit_id=${{ github.event.pull_request.head.sha }}
            pr_num=${{ github.event.pull_request.number }}
            target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}
          elif [[ "${{ github.ref_type }}" == "tag" ]]; then
            commit_id=${{ github.sha }}
            tag_name=${{ github.ref_name }}
            target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}
          else
            commit_id=${{ github.sha }}
            branch_name=${{ github.ref_name }}
            target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}
          fi
          wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
          push_file=$(realpath bos_tools.py)
          python -m pip install bce-python-sdk==0.9.29
          ls
          python ${push_file} FastDeploy.tar.gz ${target_path}
          target_path_stripped="${target_path#paddle-github-action/}"
          REPO_ARCHIVE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
          echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
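The `target_path` branching above gives each event type its own BOS prefix (PR/, TAG/, BRANCH/). An illustrative expansion for a pull-request event; the PR number and commit id here are hypothetical:

```bash
pr_num=1234; commit_id=fa5a07b8fc                      # hypothetical values
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}
target_path_stripped="${target_path#paddle-github-action/}"
echo "https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz"
# -> https://paddle-github-action.bj.bcebos.com/PR/FastDeploy/1234/fa5a07b8fc/FastDeploy.tar.gz
```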
177  .github/workflows/_logprob_test_linux.yml  vendored  Normal file

@@ -0,0 +1,177 @@
name: Run FastDeploy LogProb Tests
description: "Run FastDeploy LogProb Tests"

on:
  workflow_call:
    inputs:
      DOCKER_IMAGE:
        description: "Build Images"
        required: true
        type: string
        default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
      PADDLETEST_ARCHIVE_URL:
        description: "URL of the compressed PaddleTest code archive."
        required: true
        type: string
        default: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
      FASTDEPLOY_WHEEL_URL:
        description: "URL of the FastDeploy Wheel."
        required: true
        type: string
      CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""
      MODEL_CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""

jobs:
  run_tests_logprob:
    runs-on: [self-hosted, GPU-h20-1Cards]
    steps:
      - name: Code Prepare
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
        run: |
          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
            -e "REPO_NAME=${REPO_NAME}" \
            -e "BASE_BRANCH=${BASE_BRANCH}" \
            ${docker_image} /bin/bash -c '
            rm -rf /workspace/*
          '
          wget -q ${paddletest_archive_url}
          tar -xf PaddleTest.tar.gz
          rm -rf PaddleTest.tar.gz
          cd PaddleTest
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          git log -n 3 --oneline
      - name: logprob test
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
          CACHE_DIR: ${{ inputs.CACHE_DIR }}
          MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
        run: |
          runner_name="${{ runner.name }}"
          CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
          DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
          DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)

          FLASK_PORT=$((42068 + DEVICE_PORT * 100))
          FD_API_PORT=$((42088 + DEVICE_PORT * 100))
          FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
          FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
          echo "Test ENV Parameter:"
          echo "========================================================="
          echo "FLASK_PORT=${FLASK_PORT}"
          echo "FD_API_PORT=${FD_API_PORT}"
          echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
          echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
          echo "DEVICES=${DEVICES}"
          echo "========================================================="

          CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
          echo "CACHE_DIR is set to ${CACHE_DIR}"
          if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
            touch "${CACHE_DIR}/gitconfig"
          fi
          if [ ! -d "${MODEL_CACHE_DIR}" ]; then
            echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
            exit 1
          fi

          PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
          LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
          echo "==== LOG_FILE is ${LOG_FILE} ===="

          echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE

          for port in "${PORTS[@]}"; do
            PIDS=$(lsof -t -i :$port || true)
            if [ -n "$PIDS" ]; then
              echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
              echo "$PIDS" | xargs -r kill -9
              echo "Port $port cleared" | tee -a $LOG_FILE
            else
              echo "Port $port is free" | tee -a $LOG_FILE
            fi
          done

          echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE

          echo "========================================================="
          echo "Ensuring no stale container named ${runner_name} ..."
          if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
            echo "Removing stale container: ${runner_name}"
            docker rm -f ${runner_name} || true
          fi

          docker run --rm --ipc=host --pid=host --net=host \
            --name ${runner_name} \
            -v $(pwd):/workspace \
            -w /workspace \
            -e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
            -e "FD_API_PORT=${FD_API_PORT}" \
            -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
            -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
            -e "FLASK_PORT=${FLASK_PORT}" \
            -v "${MODEL_CACHE_DIR}:/MODELDATA" \
            -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
            -v "${CACHE_DIR}/.cache:/root/.cache" \
            -v "${CACHE_DIR}/ConfigDir:/root/.config" \
            -e TZ="Asia/Shanghai" \
            --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
            python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/

            pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

            python -m pip install ${fastdeploy_wheel_url}

            wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
            chmod +x ./llm-deploy-linux-amd64
            ./llm-deploy-linux-amd64 -python python3.10 \
              -model_name ERNIE-4.5-0.3B-Paddle \
              -model_path /MODELDATA \
              --skip install

            cd PaddleTest/framework/ServeTest
            python3.10 deploy.py > dd.log 2>&1 &
            sleep 3
            curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
              -H "Content-Type: application/json" \
              -d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"

            curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
            set +e
            rm -rf ./baseline_output
            cp -r baseline/ERNIE-4.5-0.3B-Paddle ./baseline_output
            LOGPROB_EXIT_CODE=0
            python3.10 lanucher.py --request_template TOKEN_LOGPROB --url http://localhost:${FD_API_PORT}/v1/chat/completions --case ./cases/demo.yaml --concurrency 1 --name demo --exe logprob || LOGPROB_EXIT_CODE=$?
            echo "LOGPROB_EXIT_CODE=${LOGPROB_EXIT_CODE}" > /workspace/exit_code.env
            curl -X POST http://localhost:${FLASK_PORT}/stop
            sleep 10s
            cat *result.log
            exit 0
          '
          if [ $? -ne 0 ];then
            exit 1
          fi

          if [ -f exit_code.env ]; then
            cat exit_code.env >> $GITHUB_ENV
          fi
      - name: logprob test result
        if: ${{ env.LOGPROB_EXIT_CODE != 0 }}
        shell: bash
        run: |
          echo "logprob test failed with exit code ${{ env.LOGPROB_EXIT_CODE }}"
          exit 8
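The four port bases (42058/42068/42078/42088) are offset by `DEVICE_PORT * 100`, so runners pinned to different first GPUs never contend for the same ports. For example:

```bash
DEVICE_PORT=3                                  # first GPU index of this runner
echo $((42068 + DEVICE_PORT * 100))            # FLASK_PORT           -> 42368
echo $((42088 + DEVICE_PORT * 100))            # FD_API_PORT          -> 42388
echo $((42058 + DEVICE_PORT * 100))            # FD_ENGINE_QUEUE_PORT -> 42358
echo $((42078 + DEVICE_PORT * 100))            # FD_METRICS_PORT      -> 42378
```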
148  .github/workflows/_pre_ce_test.yml  vendored  Normal file

@@ -0,0 +1,148 @@
name: Pre-CE-Test

on:
  workflow_call:
    inputs:
      DOCKER_IMAGE:
        description: "Build Images"
        required: true
        type: string
        default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126"
      FASTDEPLOY_ARCHIVE_URL:
        description: "URL of the compressed FastDeploy code archive."
        required: true
        type: string
      FASTDEPLOY_WHEEL_URL:
        description: "URL of the FastDeploy Wheel."
        required: true
        type: string
      CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""
      MODEL_CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""

concurrency:
  group: ${{ github.event.pull_request.number }}
  cancel-in-progress: true

jobs:
  run_ce_cases:
    runs-on: [self-hosted, PRE_CE_RUN_2Card]
    steps:
      - name: Print current runner name
        run: |
          echo "Current runner name: ${{ runner.name }}"
      - name: Code Prepare
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
        run: |
          set -x
          REPO="https://github.com/${{ github.repository }}.git"
          FULL_REPO="${{ github.repository }}"
          REPO_NAME="${FULL_REPO##*/}"
          BASE_BRANCH="${{ github.base_ref }}"

          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
            -e "REPO_NAME=${REPO_NAME}" \
            ${docker_image} /bin/bash -c '
            if [ -d ${REPO_NAME} ]; then
              echo "Directory ${REPO_NAME} exists, removing it..."
              rm -rf ${REPO_NAME}*
            fi
          '

          wget -q ${fd_archive_url}
          tar -xf FastDeploy.tar.gz
          rm -rf FastDeploy.tar.gz
          cd FastDeploy
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          git log -n 3 --oneline

      - name: Run CI unittest
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
          CACHE_DIR: ${{ inputs.CACHE_DIR }}
          MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
        run: |
          runner_name="${{ runner.name }}"
          CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
          DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
          DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)

          FLASK_PORT=$((42068 + DEVICE_PORT * 100))
          FD_API_PORT=$((42088 + DEVICE_PORT * 100))
          FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
          FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
          echo "Test ENV Parameter:"
          echo "========================================================="
          echo "FLASK_PORT=${FLASK_PORT}"
          echo "FD_API_PORT=${FD_API_PORT}"
          echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
          echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
          echo "DEVICES=${DEVICES}"
          echo "========================================================="

          CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
          echo "CACHE_DIR is set to ${CACHE_DIR}"
          if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
            touch "${CACHE_DIR}/gitconfig"
          fi

          PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
          LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
          echo "==== LOG_FILE is ${LOG_FILE} ===="

          echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE

          for port in "${PORTS[@]}"; do
            PIDS=$(lsof -t -i :$port || true)
            if [ -n "$PIDS" ]; then
              echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
              echo "$PIDS" | xargs -r kill -9
              echo "Port $port cleared" | tee -a $LOG_FILE
            else
              echo "Port $port is free" | tee -a $LOG_FILE
            fi
          done

          echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE

          echo "========================================================="
          echo "Ensuring no stale container named ${runner_name} ..."
          if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
            echo "Removing stale container: ${runner_name}"
            docker rm -f ${runner_name} || true
          fi

          docker run --rm --net=host \
            --name ${runner_name} \
            -v $(pwd):/workspace \
            -w /workspace \
            -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
            -v "${CACHE_DIR}/.cache:/root/.cache" \
            -v "${CACHE_DIR}/ConfigDir:/root/.config" \
            -v "${MODEL_CACHE_DIR}:/ModelData:ro" \
            -e "MODEL_PATH=/ModelData" \
            -e "FD_API_PORT=${FD_API_PORT}" \
            -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
            -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
            -e "FLASK_PORT=${FLASK_PORT}" \
            -e "fd_wheel_url=${fd_wheel_url}" \
            --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
            git config --global --add safe.directory /workspace/FastDeploy
            cd FastDeploy
            python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
            python -m pip install ${fd_wheel_url}
            bash scripts/run_pre_ce.sh
          '
282  .github/workflows/_unit_test_coverage.yml  vendored  Normal file

@@ -0,0 +1,282 @@
name: Run FastDeploy Unit Tests and Coverage
description: "Run FastDeploy Unit Tests and Coverage"

on:
  workflow_call:
    inputs:
      DOCKER_IMAGE:
        description: "Build Images"
        required: true
        type: string
        default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
      FASTDEPLOY_ARCHIVE_URL:
        description: "URL of the compressed FastDeploy code archive."
        required: true
        type: string
      FASTDEPLOY_WHEEL_URL:
        description: "URL of the FastDeploy Wheel."
        required: true
        type: string
      CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""
      MODEL_CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""

jobs:
  run_tests_with_coverage:
    runs-on: [self-hosted, GPU-h1z1-2Cards]
    outputs:
      diff_cov_file_url: ${{ steps.cov_upload.outputs.diff_cov_file_url }}
      unittest_failed_url: ${{ steps.cov_upload.outputs.unittest_failed_url }}
      diff_cov_result_json_url: ${{ steps.cov_upload.outputs.diff_cov_result_json_url }}
    steps:
      - name: Code Prepare
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
        run: |
          set -x
          REPO="https://github.com/${{ github.repository }}.git"
          FULL_REPO="${{ github.repository }}"
          REPO_NAME="${FULL_REPO##*/}"
          BASE_BRANCH="${{ github.base_ref }}"

          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
            -e "REPO_NAME=${REPO_NAME}" \
            ${docker_image} /bin/bash -c '
            if [ -d ${REPO_NAME} ]; then
              echo "Directory ${REPO_NAME} exists, removing it..."
              rm -rf ${REPO_NAME}*
            fi
          '

          wget -q ${fd_archive_url}
          tar -xf FastDeploy.tar.gz
          rm -rf FastDeploy.tar.gz
          cd FastDeploy
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          git log -n 3 --oneline
      - name: Run FastDeploy Unit Tests and Coverage
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
          CACHE_DIR: ${{ inputs.CACHE_DIR }}
          BASE_REF: ${{ github.event.pull_request.base.ref }}
          MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
        run: |
          set -x
          runner_name="${{ runner.name }}"
          CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
          DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
          DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)

          FLASK_PORT=$((42068 + DEVICE_PORT * 100))
          FD_API_PORT=$((42088 + DEVICE_PORT * 100))
          FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
          FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
          echo "Test ENV Parameter:"
          echo "========================================================="
          echo "FLASK_PORT=${FLASK_PORT}"
          echo "FD_API_PORT=${FD_API_PORT}"
          echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
          echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
          echo "DEVICES=${DEVICES}"
          echo "========================================================="

          CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
          echo "CACHE_DIR is set to ${CACHE_DIR}"
          if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
            touch "${CACHE_DIR}/gitconfig"
          fi

          PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
          LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
          echo "==== LOG_FILE is ${LOG_FILE} ===="

          echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE

          for port in "${PORTS[@]}"; do
            PIDS=$(lsof -t -i :$port || true)
            if [ -n "$PIDS" ]; then
              echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
              echo "$PIDS" | xargs -r kill -9
              echo "Port $port cleared" | tee -a $LOG_FILE
            else
              echo "Port $port is free" | tee -a $LOG_FILE
            fi
          done

          echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE

          echo "========================================================="
          echo "Ensuring no stale container named ${runner_name} ..."
          if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
            echo "Removing stale container: ${runner_name}"
            docker rm -f ${runner_name} || true
          fi

          docker run --rm --net=host \
            --name ${runner_name} \
            --cap-add=SYS_PTRACE --shm-size=64G \
            -v $(pwd):/workspace -w /workspace \
            -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
            -v "${CACHE_DIR}/.cache:/root/.cache" \
            -v "${CACHE_DIR}/ConfigDir:/root/.config" \
            -v "${MODEL_CACHE_DIR}:/ModelData:ro" \
            -e "MODEL_PATH=/ModelData" \
            -e "FD_API_PORT=${FD_API_PORT}" \
            -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
            -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
            -e "FLASK_PORT=${FLASK_PORT}" \
            -e TZ="Asia/Shanghai" \
            -e "fd_wheel_url=${fd_wheel_url}" \
            -e "BASE_REF=${BASE_REF}" \
            --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '

            git config --global --add safe.directory /workspace/FastDeploy
            cd FastDeploy
            python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/

            pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

            python -m pip install coverage
            python -m pip install diff-cover
            python -m pip install ${fd_wheel_url}
            if [ -d "test/plugins" ]; then
              cd test/plugins
              python setup.py install
              cd ../..
            else
              echo "Warning: test/plugins directory not found, skipping setup.py install"
            fi
            export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
            export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
            TEST_EXIT_CODE=0
            bash scripts/coverage_run.sh || TEST_EXIT_CODE=8
            git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
            echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env
            coverage combine coveragedata/
            coverage xml -o python_coverage_all.xml
            COVERAGE_EXIT_CODE=0
            diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9
            echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
            python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml
          '
          if [ -f FastDeploy/exit_code.env ]; then
            cat FastDeploy/exit_code.env >> $GITHUB_ENV
          fi

      - name: Upload unit test result and diff coverage to bos
        id: cov_upload
        shell: bash
        run: |
          cd FastDeploy
          commit_id=${{ github.event.pull_request.head.sha }}
          pr_num=${{ github.event.pull_request.number }}
          target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
          wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
          push_file=$(realpath bos_tools.py)
          python -m pip install bce-python-sdk==0.9.29
          diff_cov_file="diff_coverage.xml"
          if [ -f ${diff_cov_file} ];then
            python ${push_file} ${diff_cov_file} ${target_path}/CoverageData
            target_path_stripped="${target_path#paddle-github-action/}"
            DIFF_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_file}
            echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_OUTPUT
            echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_ENV
          fi
          diff_cov_result_json="diff_coverage.json"
          if [ -f ${diff_cov_result_json} ];then
            python ${push_file} ${diff_cov_result_json} ${target_path}/CoverageData
            target_path_stripped="${target_path#paddle-github-action/}"
            DIFF_COV_JSON_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_result_json}
            echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT
            echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV
          fi
          unittest_result="test/failed_tests.log"
          if [ -s ${unittest_result} ];then
            python ${push_file} ${unittest_result} ${target_path}/UnitTestResult
            target_path_stripped="${target_path#paddle-github-action/}"
            UNIT_TEST_RESULT_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/UnitTestResult/${unittest_result}
            echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
            echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
          fi
      - name: Check Unit Test Success
        shell: bash
        run: |
          cd FastDeploy
          if [ "$TEST_EXIT_CODE" -eq 8 ]; then
            filename=$(basename "$unittest_failed_url")
            if [ -z "${unittest_failed_url}" ]; then
              echo "No diff unit failed file URL provided."
            else
              rm -rf "${filename}"
              wget -O ${filename} ${unittest_failed_url} || echo "Download unittest file failed, but continuing..."
            fi
            echo "Unit tests failed (exit code 8)"
            if [ -f "${filename}" ];then
              echo "Failed test cases:"
              cat "${filename}"
            fi
            exit "$TEST_EXIT_CODE"
          fi
          echo "All tests passed"

      - name: Verify Code Coverage Threshold (80%)
        shell: bash
        run: |
          cd FastDeploy
          if [ "$COVERAGE_EXIT_CODE" -eq 9 ]; then
            echo "Coverage generation failed (exit code 9)"
            filename=$(basename "$diff_cov_result_json_url")
            if [ -z "${diff_cov_result_json_url}" ]; then
              echo "No diff cov result file URL provided."
            else
              rm -rf "${filename}"
              wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..."
            fi
            if [ -f "${filename}" ];then
              echo "Failed test cases:"
              if command -v jq >/dev/null 2>&1; then
                jq . "${filename}"
              else
                cat "${filename}"
              fi
            fi
            exit "$COVERAGE_EXIT_CODE"
          fi
          echo "coverage passed"
          exit 0

  diff_coverage_report:
    needs: run_tests_with_coverage
    if: always()
    runs-on: ubuntu-latest
    steps:
      - name: coverage diff file download
        shell: bash
        env:
          diff_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url }}
        run: |
          if [ -z "${diff_cov_file_url}" ]; then
            echo "No diff coverage file URL provided."
            exit 0
          fi
          wget "${diff_cov_file_url}" -O ./diff_coverage.xml || echo "Download cov file failed, but continuing..."
      - name: Upload diff coverage report
        if: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url != null && needs.run_tests_with_coverage.outputs.diff_cov_file_url != '' }}
        uses: codecov/codecov-action@v5
        with:
          files: ./diff_coverage.xml
          name: python diff coverage
          verbose: true
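The incremental-coverage gate above can be approximated locally. A rough sketch under the same 80% threshold; the pytest invocation and the `develop` base branch are assumptions, not taken from the repo's own scripts:

```bash
python -m pip install coverage diff-cover
coverage run -m pytest test/                           # assumption: a pytest-style suite under test/
coverage xml -o coverage.xml
git diff origin/develop..HEAD --unified=0 > diff.txt   # 'develop' assumed as the base branch
# Fail if lines touched by the diff are covered below 80%:
diff-cover coverage.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json
```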
39  .github/workflows/approve.yml  vendored  Normal file

@@ -0,0 +1,39 @@
name: Approval

on:
  pull_request:
    branches:
      - develop
      - 'release/*'

jobs:
  Approval:
    name: Approval
    if: ${{ github.repository_owner == 'PaddlePaddle' }}
    runs-on: ubuntu-latest
    env:
      PR_ID: ${{ github.event.pull_request.number }}
      BRANCH: ${{ github.event.pull_request.base.ref }}
    steps:
      - name: Checkout base repo
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.base.ref }}
          fetch-depth: 1000

      - name: Merge PR to test branch
        run: |
          git fetch origin pull/${PR_ID}/merge
          git checkout -b test FETCH_HEAD
          git log -n 3 --oneline
          git remote add upstream https://github.com/PaddlePaddle/FastDeploy.git
          git fetch upstream $BRANCH

      - name: Setup python3.10
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Run approval check script
        run: |
          bash scripts/check_approval.sh
89  .github/workflows/ci_gcu.yml  vendored  Normal file

@@ -0,0 +1,89 @@
name: CI_GCU

on:
  pull_request:
    branches:
      - develop
      - 'release/*'
  workflow_dispatch:

concurrency:
  group: ${{ github.event.pull_request.number }}-gcu-ci
  cancel-in-progress: true

jobs:
  CI_GCU:
    runs-on: [self-hosted, GCU-S60-8Card]
    steps:
      - name: Print current runner name
        run: |
          echo "Current runner name: ${{ runner.name }}"

      - name: Code Checkout
        env:
          docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
        run: |
          REPO="https://github.com/${{ github.repository }}.git"
          FULL_REPO="${{ github.repository }}"
          REPO_NAME="${FULL_REPO##*/}"
          BASE_BRANCH="${{ github.base_ref }}"
          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
            -e "REPO_NAME=${REPO_NAME}" \
            -e "BASE_BRANCH=${BASE_BRANCH}" \
            ${docker_image} /bin/bash -c '
            if [ -d ${REPO_NAME} ]; then
              echo "Directory ${REPO_NAME} exists, removing it..."
              rm -rf ${REPO_NAME}
            fi
          '
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
          cd FastDeploy
          if [ "${{ github.event_name }}" = "pull_request" ]; then
            git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
            git merge pr/${{ github.event.pull_request.number }}
            git log -n 3 --oneline
          else
            git checkout ${{ github.sha }}
            git log -n 3 --oneline
          fi

      - name: Run CI unittest
        env:
          docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
        run: |
          runner_name="${{ runner.name }}"
          last_char="${runner_name: -1}"

          if [[ "$last_char" =~ [0-3] ]]; then
            gcu_id="$last_char"
          else
            gcu_id="0"
          fi
          FD_API_PORT=$((9180 + gcu_id * 100))
          FD_ENGINE_QUEUE_PORT=$((9150 + gcu_id * 100))
          FD_METRICS_PORT=$((9170 + gcu_id * 100))

          PARENT_DIR=$(dirname "$WORKSPACE")
          echo "PARENT_DIR:$PARENT_DIR"
          echo "Install drivers..."
          cd /work/deps
          bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y
          cd -
          docker run --rm --network=host --ipc=host -it --privileged \
            -v $(pwd):/workspace -w /workspace \
            -v "/home:/home" \
            -v "/work:/work" \
            -e "MODEL_PATH=/work/models" \
            -e "http_proxy=$(git config --global --get http.proxy)" \
            -e "https_proxy=$(git config --global --get https.proxy)" \
            -e "FD_API_PORT=${FD_API_PORT}" \
            -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
            -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
            ${docker_image} /bin/bash -c "
            git config --global --add safe.directory /workspace/FastDeploy
            cd FastDeploy
            bash scripts/run_ci_gcu.sh
          "
84  .github/workflows/ci_iluvatar.yml  vendored  Normal file

@@ -0,0 +1,84 @@
name: CI_ILUVATAR

on:
  pull_request:
    branches: [ develop ]
  workflow_dispatch:

concurrency:
  group: ${{ github.event.pull_request.number }}-iluvatar-ci
  cancel-in-progress: true

jobs:
  CI_ILUVATAR:
    runs-on: [self-hosted, IXUCA]
    steps:
      - name: Print current runner name
        run: |
          echo "Current runner name: ${{ runner.name }}"
      # Because the system git version is lower than 2.23, actions/checkout cannot be used.
      # - name: Checkout code
      #   uses: actions/checkout@v4

      - name: Code Checkout
        env:
          docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
        run: |
          REPO="https://github.com/${{ github.repository }}.git"
          FULL_REPO="${{ github.repository }}"
          REPO_NAME="${FULL_REPO##*/}"
          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
            -e "REPO_NAME=${REPO_NAME}" \
            ${docker_image} /bin/bash -c '
            if [ -d ${REPO_NAME} ]; then
              echo "Directory ${REPO_NAME} exists, removing it..."
              rm -rf ${REPO_NAME}
            fi
          '
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          git clone ${REPO} ${REPO_NAME}
          cd FastDeploy
          if [ "${{ github.event_name }}" = "pull_request" ]; then
            git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
            git merge pr/${{ github.event.pull_request.number }}
            git log -n 3 --oneline
          else
            git checkout ${{ github.sha }}
            git log -n 3 --oneline
          fi

      - name: Run CI unittest
        env:
          docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
        run: |
          runner_name="${{ runner.name }}"
          last_char="${runner_name: -1}"

          if [[ "$last_char" =~ [0-3] ]]; then
            gpu_id="$last_char"
          else
            gpu_id="0"
          fi
          FD_API_PORT=$((9180 + gpu_id * 100))
          FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
          FD_METRICS_PORT=$((9170 + gpu_id * 100))

          PARENT_DIR=$(dirname "$WORKSPACE")
          echo "PARENT_DIR:$PARENT_DIR"
          docker run --rm --net=host --pid=host --cap-add=ALL --privileged --shm-size=64G \
            -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev \
            -v $(pwd):/workspace -w /workspace \
            -v "/data1/fastdeploy:/data1/fastdeploy" \
            -e "MODEL_PATH=/ssd3/model" \
            -e "http_proxy=$(git config --global --get http.proxy)" \
            -e "https_proxy=$(git config --global --get https.proxy)" \
            -e "FD_API_PORT=${FD_API_PORT}" \
            -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
            -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
            ${docker_image} /bin/bash -c "
            git config --global --add safe.directory /workspace/FastDeploy
            cd FastDeploy
            bash scripts/run_ci_iluvatar.sh
          "
@@ -1,17 +1,19 @@
-name: CI
+name: CI_XPU

 on:
   pull_request:
-    branches: [ develop ]
+    branches:
+      - develop
+      - 'release/*'
   workflow_dispatch:

 concurrency:
-  group: ${{ github.event.pull_request.number }}
+  group: ${{ github.event.pull_request.number }}-xpu-ci
   cancel-in-progress: true

 jobs:
-  build:
-    runs-on: [self-hosted, GPU-L20-4Card]
+  CI_XPU:
+    runs-on: [self-hosted, XPU-P800-8Card-release]
     steps:
       - name: Print current runner name
         run: |
@@ -22,14 +24,16 @@ jobs:

       - name: Code Checkout
         env:
-          docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
+          docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
         run: |
           REPO="https://github.com/${{ github.repository }}.git"
           FULL_REPO="${{ github.repository }}"
           REPO_NAME="${FULL_REPO##*/}"
+          BASE_BRANCH="${{ github.base_ref }}"
           # Clean the repository directory before starting
           docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
             -e "REPO_NAME=${REPO_NAME}" \
+            -e "BASE_BRANCH=${BASE_BRANCH}" \
             ${docker_image} /bin/bash -c '
             if [ -d ${REPO_NAME} ]; then
               echo "Directory ${REPO_NAME} exists, removing it..."
@@ -38,7 +42,7 @@ jobs:
           '
           git config --global user.name "FastDeployCI"
           git config --global user.email "fastdeploy_ci@example.com"
-          git clone ${REPO} ${REPO_NAME}
+          git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
           cd FastDeploy
           if [ "${{ github.event_name }}" = "pull_request" ]; then
             git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
@@ -51,7 +55,7 @@ jobs:

       - name: Run CI unittest
         env:
-          docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
+          docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
         run: |
           runner_name="${{ runner.name }}"
           last_char="${runner_name: -1}"
@@ -59,7 +63,7 @@ jobs:
           if [[ "$last_char" =~ [0-3] ]]; then
             gpu_id="$last_char"
           else
-           gpu_id="0"
+            gpu_id="0"
           fi
           FD_API_PORT=$((9180 + gpu_id * 100))
           FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
@@ -67,17 +71,17 @@ jobs:

           PARENT_DIR=$(dirname "$WORKSPACE")
           echo "PARENT_DIR:$PARENT_DIR"
-          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-            -v "/ssd4/GithubActions/gitconfig:/etc/gitconfig:ro" \
-            -v "/ssd4/GithubActions/ModelData:/ModelData:ro" \
-            -v "/ssd4/GithubActions/CacheDir:/root/.cache" \
-            -v "/ssd4/GithubActions/ConfigDir:/root/.config" \
-            -e "MODEL_PATH=/ModelData" \
+          docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \
+            -v $(pwd):/workspace -w /workspace \
+            -v "/ssd3:/ssd3" \
+            -e "MODEL_PATH=/ssd3/model" \
             -e "http_proxy=$(git config --global --get http.proxy)" \
             -e "https_proxy=$(git config --global --get https.proxy)" \
             -e "FD_API_PORT=${FD_API_PORT}" \
             -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
             -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-            --gpus device=${gpu_id} ${docker_image} /bin/bash -c "
+            ${docker_image} /bin/bash -c "
             git config --global --add safe.directory /workspace/FastDeploy
             cd FastDeploy
-            bash scripts/run_ci.sh
-          "
+            bash scripts/run_ci_xpu.sh
+          "
8
.github/workflows/gh-pages.yml
vendored
8
.github/workflows/gh-pages.yml
vendored
@@ -3,8 +3,6 @@ name: Deploy GitHub Pages

on:
  push:
    branches: [ develop ]
  pull_request:
    branches: [ develop ]

permissions:
  contents: write
@@ -17,8 +15,10 @@ jobs:
      - uses: actions/setup-python@v5
        with:
          python-version: 3.x
      - run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang
      - run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang mkdocs-static-i18n
      - name: Deploy to GitHub Pages
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: mkdocs gh-deploy --force --remote-name origin
        run: |
          git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git
          mkdocs gh-deploy --force --remote-name origin
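(A hedged reading of this hunk: the newly added mkdocs-static-i18n package supplies the i18n plugin commonly used to publish parallel English and Chinese doc trees, and the set-url step authenticates the push with the workflow token; the plugin's actual language mapping would live in mkdocs.yml, which this diff does not show.)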

.github/workflows/pr_build_and_test.yml (new file, 65 lines)
@@ -0,0 +1,65 @@
name: PR Build and Test

on:
  pull_request:
    types: [opened, synchronize]
    branches: [develop, release/**]

permissions: read-all

concurrency:
  group: ${{ github.event.pull_request.number }}-${{ github.workflow }}
  cancel-in-progress: true

jobs:
  clone:
    name: FD-Clone-Linux
    uses: ./.github/workflows/_clone_linux.yml

  build:
    name: FD-Build-Linux
    needs: clone
    uses: ./.github/workflows/_build_linux.yml
    with:
      DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
      FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
      COMPILE_ARCH: "89,90"
      WITH_NIGHTLY_BUILD: "OFF"
      FD_VERSION: "0.0.0"

  resultshow:
    name: Use Build Output
    needs: build
    runs-on: ubuntu-latest
    steps:
      - name: Print wheel path
        run: |
          echo "The built wheel is located at: ${{ needs.build.outputs.wheel_path }}"

  unittest_coverage:
    name: Run FastDeploy Unit Tests and Coverage
    needs: [clone,build]
    uses: ./.github/workflows/_unit_test_coverage.yml
    with:
      DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
      FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
      FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
      MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

  logprob_test:
    name: Run FastDeploy LogProb Tests
    needs: [build]
    uses: ./.github/workflows/_logprob_test_linux.yml
    with:
      DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
      PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
      FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
      MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

  pre_ce_test:
    name: Extracted partial CE model tasks to run in CI.
    needs: [clone,build]
    uses: ./.github/workflows/_pre_ce_test.yml
    with:
      DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
      FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
      FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
      MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

.gitignore
@@ -162,3 +162,5 @@ custom_ops/tmp*
build

.ccls-cache

third_party
@@ -3,20 +3,30 @@ default_install_hook_types:
  - commit-msg
default_stages:
  - pre-commit # Run locally
  - commit-msg
  # - manual # Run in CI
repos:
  # Formatting
  - repo: https://github.com/google/yapf
    rev: v0.43.0
    hooks:
      - id: yapf
        args: [--in-place, --verbose]
  - repo: https://github.com/psf/black.git
    rev: 25.1.0
    hooks:
      - id: black
        files: \.(py|pyi)$
        additional_dependencies: [toml]
  # Import sorting
  - repo: https://github.com/PyCQA/isort
    rev: 5.11.5
    hooks:
      - id: isort
  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
  # Linting
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.7
    hooks:
      - id: ruff
        args: [--output-format, github, --fix, --line-length=120]
        args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
  # # Spell checking
  # - repo: https://github.com/codespell-project/codespell
  #   rev: v2.4.1
@@ -24,26 +34,13 @@ repos:
  #     - id: codespell
  #       additional_dependencies: ['tomli']
  #       args: ['--toml', 'pyproject.toml']
  # Import sorting
  - repo: https://github.com/PyCQA/isort
    rev: 6.0.1
    hooks:
      - id: isort
  # # Formatting
  # - repo: https://github.com/pre-commit/mirrors-clang-format
  #   rev: v20.1.3
  #   hooks:
  #     - id: clang-format
  #       # exclude: '.*'
  #       types_or: [c++, cuda]
  #       args: [--style=file, --verbose]

  # markdown
  - repo: https://github.com/jackdewinter/pymarkdown
    rev: v0.9.29
    hooks:
      - id: pymarkdown
        args: [fix]
        args: ["-d", "MD029,MD031", fix]
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
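(For local use, the standard pre-commit workflow applies to this config: `pip install pre-commit`, then `pre-commit install`, then `pre-commit run --all-files` to run the hooks above against the whole tree; hook behavior follows the revs pinned here.)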

README.md
@@ -1,3 +1,4 @@
English | [简体中文](README_CN.md)
<p align="center">
  <a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
</p>
@@ -8,20 +9,26 @@
    <a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
    <a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
    <a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>

</p>

<p align="center">
    <a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
    <a href="https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/"><b> Installation </b></a>
    |
    <a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
    |
    <a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>

</p>

--------------------------------------------------------------------------------
# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle

## News
**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.

**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)

**[2025-06] 🔥 Released FastDeploy v2.0:** Supports inference and deployment for ERNIE 4.5. Furthermore, we open-source an industrial-grade PD disaggregation with context caching, dynamic role switching for effective resource utilization to further enhance inference performance for MoE models.

@@ -43,14 +50,15 @@

## Installation

FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. For detailed installation instructions:
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, **Hygon DCUs** and other hardware. For detailed installation instructions:

- [NVIDIA GPU](./docs/get_started/installation/nvidia_gpu.md)
- [Kunlunxin XPU](./docs/get_started/installation/kunlunxin_xpu.md)
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)

**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates!
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU and MetaX GPU are currently under development and testing. Stay tuned for updates!

## Get Started

@@ -61,18 +69,19 @@ Learn how to use FastDeploy through our documentation:
- [Offline Inference Development](./docs/offline_inference.md)
- [Online Service Deployment](./docs/online_serving/README.md)
- [Full Supported Models List](./docs/supported_models.md)
- [Best Practices](./docs/best_practices/README.md)

## Supported Models

| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅(WINT4)| WIP |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|✅(WINT4)| WIP | 128K |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |

## Advanced Usage

README_CN.md (new file, 94 lines)
@@ -0,0 +1,94 @@
[English](README.md) | 简体中文
<p align="center">
  <a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
</p>
<p align="center">
    <a href=""><img src="https://img.shields.io/badge/python-3.10-aff.svg"></a>
    <a href=""><img src="https://img.shields.io/badge/os-linux-pink.svg"></a>
    <a href="https://github.com/PaddlePaddle/FastDeploy/graphs/contributors"><img src="https://img.shields.io/github/contributors/PaddlePaddle/FastDeploy?color=9ea"></a>
    <a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
    <a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
    <a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>

</p>

<p align="center">
    <a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
    <a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/"><b> 安装指导 </b></a>
    |
    <a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/quick_start"><b> 快速入门 </b></a>
    |
    <a href="https://paddlepaddle.github.io/FastDeploy/zh/supported_models/"><b> 支持模型列表 </b></a>

</p>

--------------------------------------------------------------------------------
# FastDeploy :基于飞桨的大语言模型与视觉语言模型推理部署工具包

## 最新活动
**[2025-08] 🔥 FastDeploy v2.1 全新发布:** 全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。

**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)

## 关于

**FastDeploy** 是基于飞桨(PaddlePaddle)的大语言模型(LLM)与视觉语言模型(VLM)推理部署工具包,提供**开箱即用的生产级部署方案**,核心技术特性包括:

- 🚀 **负载均衡式PD分解**:工业级解决方案,支持上下文缓存与动态实例角色切换,在保障SLO达标和吞吐量的同时优化资源利用率
- 🔄 **统一KV缓存传输**:轻量级高性能传输库,支持智能NVLink/RDMA选择
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
- 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等

## 要求

- 操作系统: Linux
- Python: 3.10 ~ 3.12

## 安装

FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU**、**天数(Iluvatar)GPU**、**燧原(Enflame)GCU**、**海光(Hygon)DCU** 以及其他硬件上进行推理部署。详细安装说明如下:

- [英伟达 GPU](./docs/zh/get_started/installation/nvidia_gpu.md)
- [昆仑芯 XPU](./docs/zh/get_started/installation/kunlunxin_xpu.md)
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)

**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 和 沐曦(MetaX)GPU 在内的其他硬件平台正在开发测试中。敬请关注更新!

## 入门指南

通过我们的文档了解如何使用 FastDeploy:
- [10分钟快速部署](./docs/zh/get_started/quick_start.md)
- [ERNIE-4.5 部署](./docs/zh/get_started/ernie-4.5.md)
- [ERNIE-4.5-VL 部署](./docs/zh/get_started/ernie-4.5-vl.md)
- [离线推理](./docs/zh/offline_inference.md)
- [在线服务](./docs/zh/online_serving/README.md)
- [模型支持列表](./docs/zh/supported_models.md)
- [最佳实践](./docs/zh/best_practices/README.md)

## 支持模型列表

| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |

## 进阶用法

- [量化](./docs/zh/quantization/README.md)
- [分离式部署](./docs/zh/features/disaggregated.md)
- [投机解码](./docs/zh/features/speculative_decoding.md)
- [前缀缓存](./docs/zh/features/prefix_caching.md)
- [分块预填充](./docs/zh/features/chunked_prefill.md)

## 致谢

FastDeploy 依据 [Apache-2.0 开源许可证](./LICENSE) 进行授权。在开发过程中,我们参考并借鉴了 [vLLM](https://github.com/vllm-project/vllm) 的部分代码,以保持接口兼容性,在此表示衷心感谢。
@@ -41,7 +41,10 @@ python -m pip install -r requirements.txt
--metric-percentiles 80,95,99,99.9,99.95,99.99: percentile values reported in the performance results
--num-prompts 1: total number of requests to send
--max-concurrency 1: benchmark concurrency
--save-result: enable result saving; results are written to a JSON file
--save-result: enable result saving; results are written to a JSON file (default False: not saved)
--debug: enable debug mode, printing each payload and each output (default False)
--shuffle: whether to shuffle the dataset (default False: no shuffling)
--seed: random seed used when shuffling the dataset (default 0)
```
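As an illustration only (host, port, backend and dataset flags are omitted because they depend on your deployment), these options combine along the lines of `python benchmark_serving.py --num-prompts 100 --max-concurrency 8 --metric-percentiles 80,95,99 --shuffle --seed 0 --save-result`.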

##### Single-request debugging against the /v1/chat/completions endpoint
@@ -105,3 +108,30 @@ python benchmark_serving.py \
    --save-result > infer_log.txt 2>&1 &
```

### Speculative decoding benchmark tool

#### Usage:

```bash
python benchmarks/benchmark_mtp.py \
    --host 127.0.0.1 --port 8000 \
    --max-concurrency 16 32 64 96 --num-prompts 256 \
    --acceptance-rate 0.8 --draft-token-steps 1 2 3 \
    --s_itl-base-model 15.88 22.84 16.47 16.93 \
    --dataset-name EBChat \
    --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
```

#### Parameter description

```bash
--host: service IP address, used to build the URL
--port: service HTTP port, used to build the URL
--max-concurrency: concurrency levels to test
--num-prompts: total number of requests to send
--acceptance-rate: simulated acceptance rate for speculative decoding
--draft-token-steps: number of draft-token steps for speculative decoding
--s_itl-base-model: decode latency of the base model, obtainable with the benchmark tool above; one value per concurrency level
--dataset-name: dataset class to use; "EBChat" reads a dataset converted to FD format
--dataset-path: path to the test dataset
```
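For reference, these parameters feed the speedup model implemented by calculate_speedup in benchmarks/benchmark_mtp.py (the new file appears later in this diff):

```latex
% Speedup model implemented by calculate_speedup in benchmarks/benchmark_mtp.py:
% a = --acceptance-rate, k = one value of --draft-token-steps,
% t_ori = the matching --s_itl-base-model latency,
% t_mtp = the mean inter-token latency measured for the MTP run.
\[
  r_{\mathrm{ac}} = \frac{\sum_{i=1}^{k} a^{i}}{1 + \sum_{i=1}^{k} a^{i}},
  \qquad
  \mathrm{speedup} = \frac{t_{\mathrm{ori}}}{(1 - r_{\mathrm{ac}})\, t_{\mathrm{mtp}}}
\]
```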
@@ -29,13 +29,14 @@ from typing import Optional
import aiohttp
from tqdm.asyncio import tqdm

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Input for requesting LLMs via API"""

    no: int
    prompt: str
    history_QA: Optional[dict]
    hyper_parameters: dict
@@ -49,11 +50,14 @@ class RequestFuncInput:
    multi_modal_content: Optional[dict] = None
    ignore_eos: bool = False
    language: Optional[str] = None
    debug: bool = False


@dataclass
class RequestFuncOutput:
    """Output for requesting LLMs via API"""

    no: int = 0
    generated_text: str = ""
    reasoning_content: str = ""
    success: bool = False
@@ -64,7 +68,7 @@ class RequestFuncOutput:
    itl: list = field(default_factory=list)  # list of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    prompt_len: int = 0
    prompt_tokens: int = 0  # number of input tokens reported by the inference side
    prompt_tokens: int = 0  # number of input tokens reported by the inference side
    error: str = ""


@@ -74,22 +78,19 @@ async def async_request_eb_openai_chat_completions(
) -> RequestFuncOutput:
    """Request an LLM using EB OpenAI"""
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Chat Completions API URL must end with 'completions'."
    assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            content.append(request_func_input.multi_modal_content)
        payload = {
            "model": "default",
            "model": request_func_input.model,
            "messages": request_func_input.history_QA,
            "stream": True,
            "stream_options": {
                "include_usage": True,
                "continuous_usage_stats": True
                "continuous_usage_stats": True,
            },
        }
        # hyperparameters are passed in via yaml
@@ -97,6 +98,10 @@ async def async_request_eb_openai_chat_completions(

        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos

        if request_func_input.debug:
            print(f"payload:{json.dumps(payload, ensure_ascii=False)}")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
@@ -104,21 +109,20 @@ async def async_request_eb_openai_chat_completions(

        output = RequestFuncOutput()
        output.prompt_len = 0
        output.no = request_func_input.no

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
            async with session.post(url=api_url, json=payload, headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                        if chunk != "[DONE]":
                            # print("####chunk:", chunk, type(chunk))
                            timestamp = time.perf_counter()
@@ -132,21 +136,20 @@ async def async_request_eb_openai_chat_completions(
                                    ttft = timestamp - st
                                    output.ttft = ttft
                                    # cached_tokens
                                    output.prompt_len = data["usage"]["prompt_tokens_details"]["cached_tokens"]
                                    output.prompt_len = (
                                        data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
                                    )

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)
                                    output.itl.append(timestamp - most_recent_timestamp)

                                output.generated_text += content or ""
                                output.reasoning_content += reason_content or ""
                                output.arrival_time.append(choices[0].get("arrival_time"))
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                                output.prompt_tokens = usage.get(
                                    "prompt_tokens")
                                output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                            elif usage := data.get("usage", {}):
                                output.output_tokens = usage.get("completion_tokens", 0)
                                output.prompt_tokens = usage.get("prompt_tokens", 0)

                            most_recent_timestamp = timestamp

@@ -159,7 +162,12 @@ async def async_request_eb_openai_chat_completions(
                    output.latency = most_recent_timestamp - st
                else:
                    error_text = await response.text()
                    print("####error response:", error_text, "####payload:", payload)
                    print(
                        "####error response:",
                        error_text,
                        "####payload:",
                        payload,
                    )
                    output.error = error_text or ""
                    output.success = False
        except Exception:
@@ -173,6 +181,8 @@ async def async_request_eb_openai_chat_completions(
            f.write(str(output) + "\n")
    if pbar:
        pbar.update(1)
    if request_func_input.debug:
        print("#####final_output:", output)
    return output


@@ -186,15 +196,14 @@ async def async_request_eb_openai_completions(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": "default",
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "stream": True,
            "stream_options": {
                "include_usage": True,
                "continuous_usage_stats": True
                "continuous_usage_stats": True,
            },
        }
        # hyperparameters are passed in via yaml
@@ -202,19 +211,25 @@ async def async_request_eb_openai_completions(

        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos

        if request_func_input.debug:
            print("payload:", json.dumps(payload, ensure_ascii=False))

        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "Content-Type": "application/json",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        output.no = request_func_input.no

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
            async with session.post(url=api_url, json=payload, headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
@@ -222,10 +237,10 @@ async def async_request_eb_openai_completions(
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                        if chunk != "[DONE]":
                            # print("####chunk:", chunk, chunk.usage)
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
@@ -235,35 +250,40 @@ async def async_request_eb_openai_completions(
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()

                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    ttft = time.perf_counter() - st
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)
                                    output.itl.append(timestamp - most_recent_timestamp)

                                generated_text += text or ""

                                most_recent_timestamp = timestamp
                                output.arrival_time.append(choices[0].get("arrival_time"))
                                generated_text += text or ""
                                output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                            elif usage := data.get("usage"):
                                output.prompt_tokens = usage.get(
                                    "prompt_tokens")
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                                output.prompt_tokens = usage.get("prompt_tokens")
                                output.output_tokens = usage.get("completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT."
                            "This response will be marked as failed!")
                            "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
                        )

                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st

                    if output.generated_text == "":
                        output.success = False
                        output.error = "No generated text found!"
                    else:
                        output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
@@ -272,6 +292,9 @@ async def async_request_eb_openai_completions(
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if request_func_input.debug:
            print(f"final_output:{output}")

        if pbar:
            pbar.update(1)
        return output
@@ -285,8 +308,7 @@ async def async_request_tgi(
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        params = {
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
@@ -333,8 +355,7 @@ async def async_request_tgi(

                            # Decoding phase
                            else:
                                output.itl.append(timestamp -
                                                  most_recent_timestamp)
                                output.itl.append(timestamp - most_recent_timestamp)

                            most_recent_timestamp = timestamp
                            output.arrival_time.append(data["arrival_time"])
@@ -363,8 +384,7 @@ async def async_request_trt_llm(
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
@@ -389,8 +409,7 @@ async def async_request_trt_llm(
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data:")
                        chunk = chunk_bytes.decode("utf-8").removeprefix("data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
@@ -402,8 +421,7 @@ async def async_request_trt_llm(

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

@@ -428,8 +446,7 @@ async def async_request_deepspeed_mii(
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Request an LLM using Deepspeed MII"""
    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:

        payload = {
            "prompt": request_func_input.prompt,
@@ -447,19 +464,16 @@ async def async_request_deepspeed_mii(

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
            async with session.post(url=request_func_input.api_url, json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    if "choices" in parsed_resp:
                        output.generated_text = parsed_resp["choices"][0][
                            "text"]
                        output.generated_text = parsed_resp["choices"][0]["text"]
                    elif "text" in parsed_resp:
                        output.generated_text = parsed_resp["text"][0]
                    else:
                        output.error = ("Unexpected response format: "
                                        "neither 'choices' nor 'text' found")
                        output.error = "Unexpected response format: " "neither 'choices' nor 'text' found"
                        output.success = False
                    output.success = True
                else:
@@ -485,26 +499,22 @@ async def async_request_openai_completions(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model_name \
                if request_func_input.model_name else request_func_input.model,
            "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
            "prompt": request_func_input.prompt,
            # "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            #"stream_options": {
            # "stream_options": {
            #     "include_usage": True,
            #},
            # },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos

        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
@@ -513,8 +523,7 @@ async def async_request_openai_completions(
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
            async with session.post(url=api_url, json=payload, headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
@@ -522,8 +531,7 @@ async def async_request_openai_completions(
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                        if chunk != "[DONE]":
                            # print("####chunk:", chunk, type(chunk))
                            data = json.loads(chunk)
@@ -544,21 +552,19 @@ async def async_request_openai_completions(

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                                output.output_tokens = usage.get("completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT."
                            "This response will be marked as failed!")
                            "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
                        )
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
@@ -581,25 +587,24 @@ async def async_request_openai_audio(
    """Request an LLM using OpenAI"""
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("transcriptions", "translations"
         )), "OpenAI Chat Completions API URL must end with 'transcriptions' "
        ("transcriptions", "translations")
    ), "OpenAI Chat Completions API URL must end with 'transcriptions' "
    "or `translations`."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        content = [{"type": "text", "text": request_func_input.prompt}]
        payload = {
            "model": request_func_input.model_name \
                if request_func_input.model_name else request_func_input.model,
            "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "language": "en",
            # Flattened due to multipart/form-data
            "stream_include_usage": True,
            "stream_continuous_usage_stats": True
            "stream_continuous_usage_stats": True,
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
@@ -614,9 +619,9 @@ async def async_request_openai_audio(
            buffer.seek(0)
            return buffer

        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
        with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
            form = aiohttp.FormData()
            form.add_field('file', f, content_type='audio/wav')
            form.add_field("file", f, content_type="audio/wav")
            for key, value in payload.items():
                form.add_field(key, str(value))

@@ -628,24 +633,20 @@ async def async_request_openai_audio(
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url,
                                    data=form,
                                    headers=headers) as response:
            async with session.post(url=api_url, data=form, headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                content = choices[0]["delta"].get(
                                    "content")
                                content = choices[0]["delta"].get("content")
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
@@ -653,13 +654,11 @@ async def async_request_openai_audio(

                                # Decoding phase
                                else:
                                    output.itl.append(
                                        timestamp - most_recent_timestamp)
                                    output.itl.append(timestamp - most_recent_timestamp)

                                generated_text += content or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                                output.output_tokens = usage.get("completion_tokens")

                            most_recent_timestamp = timestamp

@@ -693,8 +692,11 @@ ASYNC_REQUEST_FUNCS = {
}

OPENAI_COMPATIBLE_BACKENDS = [
    k for k, v in ASYNC_REQUEST_FUNCS.items()
    if v in (async_request_openai_completions,
             async_request_eb_openai_chat_completions)
    k
    for k, v in ASYNC_REQUEST_FUNCS.items()
    if v
    in (
        async_request_openai_completions,
        async_request_eb_openai_chat_completions,
    )
]
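All of these request functions share one server-sent-events loop. A minimal, self-contained sketch of that pattern follows; the endpoint and payload are hypothetical placeholders, not FastDeploy specifics:

```python
# A minimal sketch of the SSE streaming pattern used by the request
# functions above; api_url and payload are hypothetical placeholders.
import asyncio
import json
import time

import aiohttp


async def stream_once(api_url: str, payload: dict) -> dict:
    st = time.perf_counter()
    ttft, itl, events, last = 0.0, [], [], None
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.post(url=api_url, json=payload) as response:
            response.raise_for_status()
            async for chunk_bytes in response.content:
                chunk = chunk_bytes.strip().decode("utf-8").removeprefix("data: ")
                if not chunk or chunk == "[DONE]":
                    continue
                now = time.perf_counter()
                if ttft == 0.0:
                    ttft = now - st  # time to first token
                else:
                    itl.append(now - last)  # inter-token latency
                last = now
                events.append(json.loads(chunk))  # one JSON event per SSE line
    return {"ttft": ttft, "itl": itl, "latency": (last or st) - st, "events": events}


# asyncio.run(stream_once("http://127.0.0.1:8000/v1/chat/completions",
#                         {"stream": True, "messages": []}))
```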
@@ -26,9 +26,9 @@ from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Callable, Optional, Union
from PIL import Image
from typing import Any, Optional, Union

from PIL import Image

logger = logging.getLogger(__name__)

@@ -39,6 +39,7 @@ class SampleRequest:
    Represents a single inference request for benchmarking.
    """

    no: int
    prompt: Union[str, Any]
    history_QA: Union[str, Any]
    json_data: Optional[dict]
@@ -48,6 +49,7 @@ class SampleRequest:

class BenchmarkDataset(ABC):
    """BenchmarkDataset"""

    DEFAULT_SEED = 0
    IS_MULTIMODAL = False

@@ -55,6 +57,7 @@ class BenchmarkDataset(ABC):
        self,
        dataset_path: Optional[str] = None,
        random_seed: int = DEFAULT_SEED,
        shuffle: bool = False,
        hyperparameter_path: Optional[str] = None,
    ) -> None:
        """
@@ -68,9 +71,9 @@ class BenchmarkDataset(ABC):
        self.dataset_path = dataset_path
        # Set the random seed, ensuring that a None value is replaced with the
        # default seed.
        self.random_seed = (random_seed
                            if random_seed is not None else self.DEFAULT_SEED)
        self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
        self.data = None
        self.shuffle = shuffle
        self.hyperparameter_path = hyperparameter_path
        self.hyperparameters = {}

@@ -85,8 +88,7 @@ class BenchmarkDataset(ABC):
            NotImplementedError: If a subclass does not implement this method.
        """
        # TODO (jenniferzhao): add support for downloading data
        raise NotImplementedError(
            "load_data must be implemented in subclasses.")
        raise NotImplementedError("load_data must be implemented in subclasses.")

    @abstractmethod
    def sample(self, num_requests: int) -> list[SampleRequest]:
@@ -105,8 +107,7 @@ class BenchmarkDataset(ABC):
        """
        raise NotImplementedError("sample must be implemented in subclasses.")

    def maybe_oversample_requests(self, requests: list[SampleRequest],
                                  num_requests: int) -> None:
    def maybe_oversample_requests(self, requests: list[SampleRequest], num_requests: int) -> None:
        """
        Oversamples the list of requests if its size is less than the desired
        number.
@@ -117,11 +118,9 @@ class BenchmarkDataset(ABC):
        """
        if len(requests) < num_requests:
            random.seed(self.random_seed)
            additional = random.choices(requests,
                                        k=num_requests - len(requests))
            additional = random.choices(requests, k=num_requests - len(requests))
            requests.extend(additional)
            logger.info("Oversampled requests to reach %d total samples.",
                        num_requests)
            logger.info("Oversampled requests to reach %d total samples.", num_requests)


def is_valid_sequence(
@@ -141,14 +140,12 @@ def is_valid_sequence(
    """
    # Check for invalid conditions
    prompt_too_short = prompt_len < min_len
    output_too_short = (not skip_min_output_len_check) and (output_len
                                                            < min_len)
    output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
    prompt_too_long = prompt_len > max_prompt_len
    combined_too_long = (prompt_len + output_len) > max_total_len

    # Return True if none of the invalid conditions are met
    return not (prompt_too_short or output_too_short or prompt_too_long
                or combined_too_long)
    return not (prompt_too_short or output_too_short or prompt_too_long or combined_too_long)


def process_image(image: Any) -> Mapping[str, Any]:
@@ -171,28 +168,25 @@ def process_image(image: Any) -> Mapping[str, Any]:
    Raises:
        ValueError: If the input is not a supported type.
    """
    if isinstance(image, dict) and 'bytes' in image:
        image = Image.open(BytesIO(image['bytes']))
    if isinstance(image, dict) and "bytes" in image:
        image = Image.open(BytesIO(image["bytes"]))
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
        with io.BytesIO() as image_data:
            image.save(image_data, format="JPEG")
            image_base64 = base64.b64encode(
                image_data.getvalue()).decode("utf-8")
            image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_base64}"
            },
            "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
        }

    if isinstance(image, str):
        image_url = (image if image.startswith(
            ("http://", "file://")) else f"file://{image}")
        image_url = image if image.startswith(("http://", "file://")) else f"file://{image}"
        return {"type": "image_url", "image_url": {"url": image_url}}

    raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
                     " or str or dictionary with raw image bytes.")
    raise ValueError(
        f"Invalid image input {image}. Must be a PIL.Image.Image" " or str or dictionary with raw image bytes."
    )


class EBDataset(BenchmarkDataset):
@@ -219,6 +213,10 @@ class EBDataset(BenchmarkDataset):
        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = [json.loads(i.strip()) for i in f.readlines()]

        if self.shuffle:
            random.seed(self.random_seed)
            random.shuffle(self.data)

    def sample(
        self,
        num_requests: int,
@@ -229,6 +227,7 @@ class EBDataset(BenchmarkDataset):
        **kwargs,
    ) -> list:
        samples: list = []
        cnt = 1
        for entry in self.data:
            if len(samples) >= num_requests:
                break
@@ -242,15 +241,17 @@ class EBDataset(BenchmarkDataset):
            new_output_len = int(entry["max_dec_len"])

            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(
                    prompt, None)
                prompt = self.apply_multimodal_chat_transformation(prompt, None)
            samples.append(
                SampleRequest(
                    no=cnt,
                    prompt=prompt,
                    prompt_len=self.prompt_len,
                    history_QA=[],
                    expected_output_len=new_output_len,
                ))
                )
            )
            cnt += 1

        self.maybe_oversample_requests(samples, num_requests)
        return samples
@@ -261,6 +262,7 @@ class EBChatDataset(BenchmarkDataset):
    Implements the ShareGPT dataset. Loads data from a JSON file and generates
    sample requests based on conversation turns.
    """

    prompt_len: int

    def __init__(self, **kwargs) -> None:
@@ -274,6 +276,10 @@ class EBChatDataset(BenchmarkDataset):
        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = [json.loads(i.strip()) for i in f.readlines()]

        if self.shuffle:
            random.seed(self.random_seed)
            random.shuffle(self.data)

    def sample(
        self,
        num_requests: int,
@@ -284,6 +290,7 @@ class EBChatDataset(BenchmarkDataset):
        **kwargs,
    ) -> list:
        samples: list = []
        cnt = 1
        for entry in self.data:
            if len(samples) >= num_requests:
                break
@@ -293,17 +300,18 @@ class EBChatDataset(BenchmarkDataset):
            new_output_len = int(entry.get("max_tokens", 12288))

            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(
                    prompt, None)
                prompt = self.apply_multimodal_chat_transformation(prompt, None)
            samples.append(
                SampleRequest(
                    no=cnt,
                    json_data=json_data,
                    prompt=prompt,
                    prompt_len=0,
                    history_QA=history_QA,
                    expected_output_len=new_output_len,
                ))
                )
            )
            cnt += 1

        self.maybe_oversample_requests(samples, num_requests)
        return samples
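To make the sampling flow concrete, here is a hedged sketch of how these dataset classes are driven, mirroring the call pattern used by benchmark_mtp.py later in this diff; the dataset path is illustrative, and it assumes the constructor loads the file as the load_data hunks suggest:

```python
# Hypothetical usage of EBChatDataset; the path below is illustrative.
from benchmark_dataset import EBChatDataset

dataset = EBChatDataset(
    dataset_path="./filtered_sharedgpt_2000_input_1136_output_200_fd.json",
    shuffle=True,  # reshuffles self.data with random_seed before sampling
)
requests = dataset.sample(num_requests=256)
# Each SampleRequest carries a sequence number, the prompt, the chat
# history and the expected output length.
print(requests[0].no, requests[0].expected_output_len)
```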

benchmarks/benchmark_mtp.py (new file, 178 lines)
@@ -0,0 +1,178 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import asyncio
import contextlib
import os
from typing import Union

from benchmark_dataset import EBChatDataset, EBDataset
from benchmark_serving import benchmark


def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
    dataset_mapping = {
        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
    }

    try:
        input_requests = dataset_mapping[dataset_name]()
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    return input_requests


class FakeTokenizer:
    def encode(self, text: str, add_special_tokens: bool = False):
        return []


def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    selected_percentile_metrics = ["s_itl"]
    selected_percentiles = []
    # Run benchmark
    results = asyncio.run(
        benchmark(
            backend="openai-chat",
            api_url=f"{base_url}/v1/chat/completions",
            base_url=base_url,
            model_id="default",
            model_name="default",
            input_requests=input_requests,
            hyper_parameters={},
            logprobs=None,
            request_rate=float("inf"),
            burstiness=1.0,
            disable_tqdm=disable_tqdm,
            profile=False,
            selected_percentile_metrics=selected_percentile_metrics,
            selected_percentiles=selected_percentiles,
            ignore_eos=False,
            goodput_config_dict=None,
            max_concurrency=max_concurrency,
            lora_modules=None,
            extra_body=None,
        )
    )

    record = {
        "mean_s_itl_ms": results["mean_s_itl_ms"],
    }

    return record


def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):

    tmp = 0.0
    for i in range(draft_token_step):
        tmp += pow(acceptance_rate, i + 1)

    r_ac = tmp / (1 + tmp)

    return t_ori / ((1 - r_ac) * t_mtp)


def main(args):
    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)

    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError("--max_concurrency should be same length as --s_itl_base_model")

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup
        print("Starting warmup...")
        with open(os.devnull, "w") as f:
            with contextlib.redirect_stdout(f):
                send_one_batch(
                    base_url,
                    max_concurrency,
                    input_requests[0:max_concurrency],
                    True,
                )

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        metric_header = "Speed up"
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))
        print("=" * 50)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
    )
    parser.add_argument(
        "--s_itl-base-model",
        type=float,
        nargs="+",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
    )
    args = parser.parse_args()

    main(args)
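A quick sanity check on calculate_speedup: with acceptance_rate = 0.8 and draft_token_step = 2, tmp = 0.8 + 0.64 = 1.44, so r_ac = 1.44 / 2.44 ≈ 0.59 and the predicted speedup is t_ori / (0.41 × t_mtp), i.e. roughly 2.44 × t_ori / t_mtp.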
(File diff suppressed because it is too large)
@@ -24,9 +24,11 @@ import os
from typing import Any


def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
                                        metrics: dict[str, list],
                                        extra_info: dict[str, Any]) -> list:
def convert_to_pytorch_benchmark_format(
    args: argparse.Namespace,
    metrics: dict[str, list],
    extra_info: dict[str, Any],
) -> list:
    """
    Save the benchmark results in the format used by PyTorch OSS benchmark with
    one metric per record
@@ -54,12 +56,10 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
            },
        }

        tp = record["benchmark"]["extra_info"]["args"].get(
            "tensor_parallel_size")
        tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
        # Save tensor_parallel_size parameter if it's part of the metadata
        if not tp and "tensor_parallel_size" in extra_info:
            record["benchmark"]["extra_info"]["args"][
                "tensor_parallel_size"] = extra_info["tensor_parallel_size"]
            record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = extra_info["tensor_parallel_size"]

        records.append(record)

@@ -68,6 +68,7 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,

class InfEncoder(json.JSONEncoder):
    """InfEncoder"""

    def clear_inf(self, o: Any):
        """clear_inf"""
        if isinstance(o, dict):
@@ -87,4 +88,3 @@ def write_to_json(filename: str, records: list) -> None:
    """write_to_json"""
    with open(filename, "w") as f:
        json.dump(records, f, cls=InfEncoder)
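A hedged illustration of why InfEncoder exists: clear_inf sanitizes float("inf") values before serialization, so the output is valid JSON rather than the bare Infinity token the stock encoder would emit (the exact substitute value is defined in clear_inf, whose body this hunk truncates). The import below assumes the module is importable as benchmark_utils:

```python
# Hypothetical usage of write_to_json / InfEncoder; record content is
# illustrative, and the module name benchmark_utils is an assumption.
from benchmark_utils import write_to_json

records = [{"name": "demo", "request_rate": float("inf"), "p99_ms": 123.4}]
# clear_inf replaces the inf value so the dumped file parses as valid JSON.
write_to_json("benchmark_records.json", records)
```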

benchmarks/quick_benchmark.py (new file, 1173 lines; diff suppressed because it is too large)
@@ -3,3 +3,4 @@ tqdm
numpy
Pillow
pyyaml
requests
@@ -7,4 +7,4 @@ tensor_parallel_size: 1
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
quantization: wint4
|
||||
reasoning_parser: ernie-45-vl
|
||||
reasoning_parser: ernie-45-vl
|
||||
|
@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
|
||||
pd_comm_port: "2334"
|
||||
max_num_batched_tokens: 384
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
|
@@ -9,4 +9,4 @@ cache_queue_port: 55664
|
||||
engine_worker_queue_port: 6677
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7675,7676,7677,7678"
|
||||
pd_comm_port: "2333"
|
||||
pd_comm_port: "2333"
|
||||
|
@@ -3,3 +3,4 @@ max_num_seqs: 96
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 4
|
||||
quantization: wint4
|
||||
|
@@ -10,4 +10,4 @@ engine_worker_queue_port: 6677
|
||||
num_gpu_blocks_override: 1024
|
||||
cache_transfer_protocol: "rdma"
|
||||
rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678"
|
||||
pd_comm_port: "2334"
|
||||
pd_comm_port: "2334"
|
||||
|
@@ -10,4 +10,4 @@ splitwise_role: decode
|
||||
engine_worker_queue_port: 6678
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7671,7672,7673,7674"
|
||||
pd_comm_port: "2334"
|
||||
pd_comm_port: "2334"
|
||||
|
@@ -9,4 +9,4 @@ cache_queue_port: 55664
|
||||
engine_worker_queue_port: 6677
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7675,7676,7677,7678"
|
||||
pd_comm_port: "2333"
|
||||
pd_comm_port: "2333"
|
||||
|
@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
|
||||
pd_comm_port: "2334"
|
||||
max_num_batched_tokens: 384
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
|
@@ -9,4 +9,4 @@ cache_queue_port: 55664
|
||||
engine_worker_queue_port: 6677
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7675,7676,7677,7678"
|
||||
pd_comm_port: "2333"
|
||||
pd_comm_port: "2333"
|
||||
|
@@ -3,3 +3,4 @@ max_num_seqs: 96
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint8
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 96
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 4
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wfp8afp8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1
@@ -3,4 +3,4 @@ max_num_seqs: 75
gpu_memory_utilization: 0.85
kv_cache_ratio: 0.75
quantization: wint4
tensor_parallel_size: 4

@@ -3,4 +3,4 @@ max_num_seqs: 25
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.75
quantization: wint8
tensor_parallel_size: 4
benchmarks/yaml/request_yaml/quick_benchmark.yaml (3 lines, normal file)

@@ -0,0 +1,3 @@
metadata:
  min_tokens: 32
  max_tokens: 33
@@ -5,4 +5,4 @@ metadata:
  max_tokens: 12288
repetition_penalty: 1.05
frequency_penalty: 0
presence_penalty: 0

@@ -5,4 +5,4 @@ metadata:
  max_tokens: 12288
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 1.5
benchmarks/yaml/request_yaml/vLLM_default.yaml (11 lines, normal file)

@@ -0,0 +1,11 @@
top_p: 1.0
temperature: 1.0
metadata:
  min_tokens: 1
  max_tokens: 30721
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0
skip_special_tokens: false
chat_template_kwargs:
  enable_thinking: true
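These request-side defaults map onto an OpenAI-style chat completion payload. A sketch of sending them to a locally served model; the URL, port, and model name are placeholders, and whether the server accepts the extension fields (repetition_penalty, skip_special_tokens, chat_template_kwargs) at the top level is an assumption:

import requests

payload = {
    "model": "default",          # placeholder model name
    "messages": [{"role": "user", "content": "Hello"}],
    "top_p": 1.0,
    "temperature": 1.0,
    "repetition_penalty": 1.0,   # assumption: passed as an extra sampling field
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "max_tokens": 30721,
    "skip_special_tokens": False,
    "chat_template_kwargs": {"enable_thinking": True},
}
# Hypothetical local endpoint; adjust host/port to your deployment.
r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print(r.json()["choices"][0]["message"]["content"])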
@@ -3,4 +3,4 @@ max_num_seqs: 64
gpu_memory_utilization: 0.9
tensor_parallel_size: 8
quantization: wint8
reasoning_parser: ernie-x1
build.sh (51 lines changed)
@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
PYTHON_VERSION=${2:-"python"}
export python=$PYTHON_VERSION
FD_CPU_USE_BF16=${3:-"false"}
# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
# These will be translated to 90a / 100a in setup_ops.py for specific features.
FD_BUILDING_ARCS=${4:-""}
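The comment above pins down the arch-translation rule (90 becomes 90a, 100 becomes 100a). A sketch of that mapping in Python, roughly as setup_ops.py might apply it; the function name and the pass-through rule for other archs are assumptions:

def translate_arch(arch: int) -> str:
    # Per the build.sh comment: SM90 and SM100 need the "a" suffix
    # to unlock arch-specific features; other archs pass through.
    return f"{arch}a" if arch in (90, 100) else str(arch)

# FD_BUILDING_ARCS="[80, 90, 100]" would become:
print([translate_arch(a) for a in (80, 90, 100)])  # ['80', '90a', '100a']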
@@ -74,8 +77,10 @@ function copy_ops(){
  is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
  if [ "$is_rocm" = "True" ]; then
    DEVICE_TYPE="rocm"
    mkdir -p ../fastdeploy/model_executor/ops/base
    cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
    cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
    echo -e "ROCM ops have been copy to fastdeploy"
    echo -e "BASE and ROCM ops have been copy to fastdeploy"
    return
  fi
  mkdir -p ../fastdeploy/model_executor/ops/base

@@ -104,6 +109,23 @@ function copy_ops(){
    return
  fi

  if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
  if [ "$if_corex" = "True" ]; then
    DEVICE_TYPE="iluvatar-gpu"
    cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
    cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
    echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
    return
  fi

  is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
  if [ "$is_gcu" = "True" ]; then
    DEVICE_TYPE="gcu"
    cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
    echo -e "gcu ops have been copy to fastdeploy"
    return
  fi

  DEVICE_TYPE="cpu"
  cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
  cd ../../../../

@@ -163,17 +185,24 @@ function build_and_install() {
    exit 1
  fi
  echo -e "${BLUE}[build]${NONE} ${GREEN}build fastdeploy wheel success${NONE}\n"
}

  echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
  cd $DIST_DIR
  find . -name "fastdeploy*.whl" | xargs ${python} -m pip install
  if [ $? -ne 0 ]; then
    cd ..
    echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
    exit 1
function version_info() {
  output_file="fastdeploy/version.txt"
  fastdeploy_git_commit_id=$(git rev-parse HEAD)
  paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
  paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
  cuda_version="nvcc-not-installed"
  if command -v nvcc &> /dev/null; then
    cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
  fi
  echo -e "${BLUE}[install]${NONE} ${GREEN}fastdeploy install success${NONE}\n"
  cd ..
  cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")

  echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
  echo "Paddle version: $paddle_version" >> $output_file
  echo "Paddle GIT COMMIT ID: $paddle_git_commit_id" >> $output_file
  echo "CUDA version: $cuda_version" >> $output_file
  echo "CXX compiler version: $cxx_version" >> $output_file
}

function cleanup() {

@@ -207,6 +236,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
  set -e

  init
  version_info
  build_and_install_ops
  build_and_install
  cleanup

@@ -237,6 +267,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
else
  init
  build_and_install_ops
  version_info
  rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR
  rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
fi
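copy_ops() routes compiled ops by probing how Paddle itself was built. The probes are the Paddle APIs quoted verbatim in the hunks above; a compact sketch of the same dispatch in Python, covering only the branches visible in this diff:

import paddle

def detect_device_type() -> str:
    # Same probes copy_ops() runs via `$python -c ...` one-liners.
    if paddle.is_compiled_with_rocm():
        return "rocm"
    if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
        return "iluvatar-gpu"
    if paddle.is_compiled_with_custom_device("gcu"):
        return "gcu"
    return "cpu"

print(detect_device_type())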
@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle

from . import jit
from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
from typing import Tuple

from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME


def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
-import torch
+import paddle
from typing import Optional

from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
    assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
-import torch
+import paddle
from typing import Any, Dict, Iterable, Tuple


# Name map for Python `eval`
typename_map: Dict[Any, str] = {
    **{t: t.__name__ for t in (bool, int, float)},
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
+    paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+    paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}

# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
    **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
-    **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+    **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
}
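The patched template maps decide how Python-side values cross the ctypes boundary: every tensor-like dtype collapses to c_void_p, and map_ctype forwards the raw device pointer. A standalone toy sketch of the idea with plain ctypes (not the DeepGEMM code itself; FakeTensor is a stand-in):

import ctypes

# Scalars keep a native ctypes type; anything pointer-like becomes void*.
ctype_map = {bool: ctypes.c_bool, int: ctypes.c_int, float: ctypes.c_float}

class FakeTensor:
    """Stand-in for a paddle.Tensor: exposes data_ptr() like the real one."""
    def __init__(self, addr): self._addr = addr
    def data_ptr(self): return self._addr

def map_ctype(value):
    # Mirrors the patched template.py: tensors pass through data_ptr(),
    # plain scalars through the ctype_map lookup.
    if hasattr(value, "data_ptr"):
        return ctypes.c_void_p(value.data_ptr())
    return ctype_map[type(value)](value)

print(map_ctype(FakeTensor(0x1234)), map_ctype(3), map_ctype(0.5))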
@@ -27,25 +27,25 @@ genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
+    paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+    paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
}


def map_ctype(value: Any) -> Any:
    if hasattr(value, 'data_ptr'):
-        if value.dtype == torch.int:
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
+import paddle
from functools import lru_cache
from typing import Tuple

@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config


-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
-                         rhs: Tuple[torch.Tensor, torch.Tensor],
-                         out: torch.Tensor) -> None:
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
-    this function will do a transposing with a set of slow PyTorch operations.
+    this function will do a transposing with a set of slow paddle operations.

Arguments:
-    lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+    lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape

-    assert n % 64 == 0 and k % 128 == 0
+    # assert n % 64 == 0 and k % 128 == 0

# Type and shape checks
-    assert m == m_ and n == n_ and k == k_
-    assert n > 0 and k > 0
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
+    # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+    # assert out.dtype == paddle.bfloat16
+    # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
-    assert rhs_scales.is_contiguous()
+    # assert rhs_scales.is_contiguous()

# Do nothing if `m` is zero
if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
-import torch
+import paddle
from typing import Tuple

from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
"""


-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
-                                              rhs: Tuple[torch.Tensor, torch.Tensor],
-                                              out: torch.Tensor, m_indices: torch.Tensor) -> None:
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
+    this function will do a transposing with a set of slow Pypaddle operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).

Arguments:
-    lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+    lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
m__ = m_indices.numel()

# Type and shape checks
-    assert m == m_ == m__ and k == k_ and n == n_
-    assert lhs_scales.shape == (m, (k + 127) // 128)
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
+    # assert m_indices.dtype == paddle.int32
+    # assert lhs.is_contiguous() and rhs.is_contiguous()
+    # assert out.is_contiguous() and m_indices.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
-    assert rhs_scales.is_contiguous()
+    # assert rhs_scales.is_contiguous()

# Do nothing if `m` is zero
if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
)
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
runtime(*args)


-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
-                                          rhs: Tuple[torch.Tensor, torch.Tensor],
-                                          out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
+    this function will do a transposing with a set of slow paddle operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.

Arguments:
-    lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+    lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor]
num_groups___ = masked_m.numel()

# Type and shape checks
-    assert num_groups == num_groups_ == num_groups__ == num_groups___
-    assert m == m_ and n == n_ and k == k_
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
+    # assert masked_m.dtype == paddle.int32
+    # assert lhs.is_contiguous() and rhs.is_contiguous()
+    # assert out.is_contiguous() and masked_m.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
-    assert rhs_scales.is_contiguous()
+    # assert rhs_scales.is_contiguous()

# Auto-tuning with compilation
global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor]

args = (lhs, lhs_scales, rhs, rhs_scales, out,
        masked_m, m,
-        torch.cuda.current_stream(), num_sms, smem_config[0])
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
-import torch
+import paddle
from typing import Any, Dict

from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
continue

# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle

_num_sms = None

@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
-    assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+    assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
_num_sms = num_sms


@@ -25,7 +25,7 @@ def get_num_sms() -> int:
"""
global _num_sms
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
-    _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+    _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms


@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
return ceil_div(x, alignment) * alignment


-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
"""
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
+    Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.

@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
+    if x.strides[0] == 1 and x.strides[1] == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True

b = x.shape[0]

# The last kernel gives a column-major TMA aligned layout
-    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+    if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
return x.squeeze(0) if remove_dim else x

# Normal layout requires transposing
-    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+    aligned_x = paddle.transpose(
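get_tma_aligned_size rounds the M extent up to the TMA alignment, and the stride checks above then detect whether a tensor is already in the column-major aligned layout. A small numeric sketch of the rounding, assuming the 16-byte alignment the docstring implies:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def get_tma_aligned_size(x: int, element_size: int) -> int:
    # Assumption: TMA wants 16-byte alignment along the M axis,
    # so the element-count alignment is 16 // element_size.
    alignment = 16 // element_size
    return ceil_div(x, alignment) * alignment

# float32 scales (4 bytes): M is rounded up to a multiple of 4.
print(get_tma_aligned_size(130, 4))  # 132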
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist


def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    # Flush L2 cache with 256 MB data
-    torch.cuda.synchronize()
-    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+    paddle.device.cuda.synchronize()
+    paddle.device.synchronize()
+    cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
    cache.zero_()

    # Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,

    # Add a large kernel to eliminate the CPU launch overhead
    if high_precision:
-        x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
+        x = paddle.randn((8192, 8192), dtype=paddle.float32)
+        y = paddle.randn((8192, 8192), dtype=paddle.float32)
        x @ y

    # Testing
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
    end_event.record()
-    torch.cuda.synchronize()
+    paddle.device.synchronize()

    return start_event.elapsed_time(end_event) / num_tests

@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
-    torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+    paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
    fn()

    if not using_nsys:
--
2.43.0
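The patched bench() keeps DeepGEMM's timing recipe: flush L2 with a 256 MB buffer, warm up, then time the kernel runs. The original uses device events; this standalone sketch substitutes synchronize plus perf_counter, which is coarser but uses only Paddle calls that appear in the diff:

import time
import paddle

def bench(fn, num_warmups: int = 5, num_tests: int = 10) -> float:
    # Flush L2 cache with 256 MB of data so each timed run starts cold.
    cache = paddle.zeros([int(256e6 // 4)], dtype=paddle.int32)
    paddle.device.synchronize()

    for _ in range(num_warmups):
        fn()
    paddle.device.synchronize()

    start = time.perf_counter()
    for _ in range(num_tests):
        fn()
    paddle.device.synchronize()  # wait for all queued kernels to finish
    return (time.perf_counter() - start) / num_tests  # seconds per call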
@@ -46,8 +46,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& encoder_batch_ids,
    const paddle::Tensor& encoder_tile_ids_per_batch,

@@ -165,8 +165,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
    seq_lens_this_time,
    seq_lens_decoder,
    seq_lens_encoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_tables,
    lambda_batch_ids,
    lambda_tile_ids_per_batch,

@@ -202,8 +202,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_tables,
    kv_batch_ids,
    kv_tile_ids_per_batch,

@@ -274,8 +274,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
    qkv,  // [token_num, num_heads, head_dim]
    seq_lens_decoder,
    seq_lens_encoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_tables,
    rotary_embs,
    qkv_out_scales,

@@ -297,8 +297,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
    qkv_out,  // [token_num, num_heads, head_dim]
    seq_lens_decoder,
    seq_lens_encoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_tables,
    rotary_embs,
    qkv_out_scales,

@@ -322,8 +322,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
    qkv,  // [token_num, num_heads, head_dim]
    seq_lens_decoder,
    seq_lens_encoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_tables,
    rotary_embs,
    qkv_out_scales,

@@ -346,8 +346,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
    qkv_out,  // [token_num, num_heads, head_dim]
    seq_lens_decoder,
    seq_lens_encoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_tables,
    rotary_embs,
    qkv_out_scales,

@@ -403,8 +403,8 @@ std::vector<paddle::Tensor> AppendAttention(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& encoder_batch_ids,
    const paddle::Tensor& encoder_tile_ids_per_batch,

@@ -462,7 +462,7 @@ std::vector<paddle::Tensor> AppendAttention(
    meta_data.max_blocks_per_seq = block_tables.dims()[1];
    meta_data.block_size = key_cache.dims()[2];
    meta_data.batch_size = cum_offsets.dims()[0];
    meta_data.batch_size = seq_lens_this_time.dims()[0];

    auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
        return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(

@@ -473,8 +473,8 @@ std::vector<paddle::Tensor> AppendAttention(
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_tables,
    encoder_batch_ids,
    encoder_tile_ids_per_batch,

@@ -550,8 +550,8 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
    const std::vector<int64_t>& seq_lens_encoder_shape,
    const std::vector<int64_t>& seq_lens_decoder_shape,
    const std::vector<int64_t>& seq_lens_this_time_shape,
    const std::vector<int64_t>& padding_offsets_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& batch_id_per_token_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& block_tables_shape,
    const std::vector<int64_t>& encoder_batch_ids_shape,
    const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,

@@ -610,8 +610,8 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
    const paddle::DataType& seq_lens_encoder_dtype,
    const paddle::DataType& seq_lens_decoder_dtype,
    const paddle::DataType& seq_lens_this_time_dtype,
    const paddle::DataType& padding_offsets_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& batch_id_per_token_dtype,
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& block_tables_dtype,
    const paddle::DataType& encoder_batch_ids_dtype,
    const paddle::DataType& encoder_tile_ids_per_batch_dtype,

@@ -688,8 +688,8 @@ PD_BUILD_STATIC_OP(append_attention)
    "seq_lens_encoder",
    "seq_lens_decoder",
    "seq_lens_this_time",
    "padding_offsets",
    "cum_offsets",
    "batch_id_per_token",
    "cu_seqlens_q",
    "block_tables",
    "encoder_batch_ids",
    "encoder_tile_ids_per_batch",
@@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_kernel(
    const int *__restrict__ seq_lens_kv,
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids_per_batch,
    const int *__restrict__ cum_offsets,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
    const int max_seq_len,
    const int max_dec_len,

@@ -114,8 +114,7 @@ __global__ void multi_query_append_attention_kernel(
    const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
    const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
    const uint32_t kv_b_stride = HEAD_DIM;
    const uint32_t q_start_seq_id =
        batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
    const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
    const uint32_t q_base_seq_id_this_block =
        (tile_id * NUM_WARPS + wid) * num_frags_x * 16;
    const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +

@@ -405,7 +404,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
    const int *__restrict__ seq_lens_kv,
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids_per_batch,
    const int *__restrict__ cum_offsets,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
    const int max_seq_len,
    const int max_dec_len,

@@ -477,8 +476,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
    const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
    const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
    const uint32_t kv_b_stride = HEAD_DIM;
    const uint32_t q_start_seq_id =
        batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
    const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
    const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
    const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
                              q_head_idx * HEAD_DIM +

@@ -775,8 +773,8 @@ void MultiQueryAppendAttention(
    const paddle::Tensor &seq_lens_q,
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &padding_offsets,
    const paddle::Tensor &cum_offsets,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    const paddle::Tensor &batch_ids,
    const paddle::Tensor &tile_ids_per_batch,

@@ -882,7 +880,7 @@ void MultiQueryAppendAttention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -939,7 +937,7 @@ void MultiQueryAppendAttention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -974,7 +972,7 @@ void MultiQueryAppendAttention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,

@@ -1009,7 +1007,8 @@ void MultiQueryAppendAttention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    padding_offsets.data<int>(),
    batch_id_per_token.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,

@@ -1062,12 +1061,11 @@ void MultiQueryAppendAttention(
    if (!is_decoder) {
      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
    }
    const int num_chunks = div_up(max_dec_len, chunk_size);
    const int num_chunks = div_up(max_seq_len, chunk_size);
    dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
    dim3 blocks(32, num_warps);
    if (num_chunks <= 1) {
    if (num_chunks <= 0) {
      auto nosplit_kv_kernel =
          multi_query_append_attention_warp1_4_kernel<NV_TYPE,
                                                      false,

@@ -1103,7 +1101,7 @@ void MultiQueryAppendAttention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1162,8 +1160,8 @@ void MultiQueryAppendAttention(
    reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
    reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(
                        const_cast<T *>(smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1171,7 +1169,7 @@ void MultiQueryAppendAttention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1207,10 +1205,10 @@ void MultiQueryAppendAttention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                        smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1227,14 +1225,14 @@ void MultiQueryAppendAttention(
    constexpr int blockx = HEAD_DIM / vec_size;
    constexpr int blocky = (128 + blockx - 1) / blockx;
    dim3 grids_merge(min(sm_count * 4, token_num),
                     num_heads);
    dim3 blocks_merge(blockx, blocky);
    merge_multi_chunks_v2_kernel<NV_TYPE,
                                 vec_size,
                                 blocky,
                                 HEAD_DIM,
                                 OUT_NV_TYPE,
                                 ENABLE_PREFILL>
        <<<grids_merge, blocks_merge, 0, stream>>>(
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),

@@ -1242,10 +1240,11 @@ void MultiQueryAppendAttention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    padding_offsets.data<int>(),
    batch_id_per_token.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                        smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1289,8 +1288,8 @@ void CascadeAppendAttentionC16Kernel(
    const paddle::Tensor& seq_lens_q,
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,

@@ -1352,8 +1351,8 @@ void CascadeAppendAttentionC16Kernel(
    seq_lens_q,
    seq_lens_kv,
    seq_lens_encoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_table,
    batch_ids,
    tile_ids_per_batch,
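Two behavioral changes recur in these hunks: the chunk count is now derived from max_seq_len rather than max_dec_len, and the no-split path triggers only when no chunk is needed at all. A small sketch of that dispatch decision in Python, with div_up as in the CUDA code:

def div_up(a: int, b: int) -> int:
    return (a + b - 1) // b

def pick_path(max_seq_len: int, chunk_size: int) -> str:
    # After the change: chunks are sized against the full sequence length.
    num_chunks = div_up(max_seq_len, chunk_size)
    # `num_chunks <= 0` replaces `num_chunks <= 1`, so a single-chunk case
    # now also goes through the split-KV plus merge path.
    return "no-split kernel" if num_chunks <= 0 else "split-KV + merge"

print(pick_path(max_seq_len=8192, chunk_size=1024))  # split-KV + merge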
@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c4_kernel(
    const int *__restrict__ seq_lens_kv,
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids_per_batch,
    const int *__restrict__ cum_offsets,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
    const int max_seq_len,
    const int max_dec_len,

@@ -144,8 +144,7 @@ __global__ void multi_query_append_attention_c4_kernel(
    const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
    const uint32_t kv_b_stride = HEAD_DIM / 2;
    const uint32_t kv_d_stride = BLOCK_SIZE / 2;
    const uint32_t q_start_seq_id =
        batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
    const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
    const uint32_t q_base_seq_id_this_block =
        (tile_id * NUM_WARPS + wid) * num_frags_x * 16;
    const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +

@@ -504,7 +503,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
    const int *__restrict__ seq_lens_kv,
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids_per_batch,
    const int *__restrict__ cum_offsets,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
    const int max_seq_len,
    const int max_dec_len,

@@ -601,8 +600,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
    const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
    const uint32_t kv_b_stride = HEAD_DIM / 2;
    const uint32_t kv_d_stride = BLOCK_SIZE / 2;
    const uint32_t q_start_seq_id =
        batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
    const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
    const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
    const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
                              q_head_idx * HEAD_DIM +

@@ -962,8 +960,8 @@ void MultiQueryAppendC4Attention(
    const paddle::Tensor &seq_lens_q,
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &padding_offsets,
    const paddle::Tensor &cum_offsets,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    const paddle::Tensor &batch_ids,
    const paddle::Tensor &tile_ids_per_batch,

@@ -1088,7 +1086,7 @@ void MultiQueryAppendC4Attention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1151,7 +1149,7 @@ void MultiQueryAppendC4Attention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1186,7 +1184,7 @@ void MultiQueryAppendC4Attention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,

@@ -1221,7 +1219,8 @@ void MultiQueryAppendC4Attention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    padding_offsets.data<int>(),
    batch_id_per_token.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,

@@ -1286,10 +1285,11 @@ void MultiQueryAppendC4Attention(
    if (!is_decoder) {
      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
    }
    const int num_chunks = div_up(max_dec_len, chunk_size);
    const int num_chunks = div_up(max_seq_len, chunk_size);
    dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
    dim3 blocks(32, num_warps);
    if (num_chunks <= 1) {
    if (num_chunks <= 0) {
      auto nosplit_kv_kernel =
          multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
                                                         uint8_t,

@@ -1333,7 +1333,7 @@ void MultiQueryAppendC4Attention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1393,15 +1393,15 @@ void MultiQueryAppendC4Attention(
    const_cast<uint8_t *>(cache_v.data<uint8_t>()),
    reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
    cache_k_zp ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(cache_k_zp.get().data<T>()))
               : nullptr,
    reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
    cache_v_zp ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(cache_v_zp.get().data<T>()))
               : nullptr,
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(
                        const_cast<T *>(smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1409,7 +1409,7 @@ void MultiQueryAppendC4Attention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1444,10 +1444,10 @@ void MultiQueryAppendC4Attention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                        smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1464,14 +1464,14 @@ void MultiQueryAppendC4Attention(
    constexpr int blockx = HEAD_DIM / vec_size;
    constexpr int blocky = (128 + blockx - 1) / blockx;
    dim3 grids_merge(min(sm_count * 4, token_num),
                     num_heads);
    dim3 blocks_merge(blockx, blocky);
    merge_multi_chunks_v2_kernel<NV_TYPE,
                                 vec_size,
                                 blocky,
                                 HEAD_DIM,
                                 OUT_NV_TYPE,
                                 ENABLE_PREFILL>
        <<<grids_merge, blocks_merge, 0, stream>>>(
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),

@@ -1479,10 +1479,11 @@ void MultiQueryAppendC4Attention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    padding_offsets.data<int>(),
    batch_id_per_token.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                        smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1526,8 +1527,8 @@ void CascadeAppendAttentionC4Kernel(
    const paddle::Tensor& seq_lens_q,
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,

@@ -1593,8 +1594,8 @@ void CascadeAppendAttentionC4Kernel(
    seq_lens_q,
    seq_lens_kv,
    seq_lens_encoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_table,
    batch_ids,
    tile_ids_per_batch,
@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c8_kernel(
    const int *__restrict__ seq_lens_kv,
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids_per_batch,
    const int *__restrict__ cum_offsets,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
    const int max_seq_len,
    const int max_dec_len,

@@ -151,8 +151,7 @@ __global__ void multi_query_append_attention_c8_kernel(
    const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
    const uint32_t kv_b_stride = HEAD_DIM;
    const uint32_t kv_d_stride = BLOCK_SIZE;
    const uint32_t q_start_seq_id =
        batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
    const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
    const uint32_t q_base_seq_id_this_block =
        (tile_id * NUM_WARPS + wid) * num_frags_x * 16;
    const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +

@@ -473,7 +472,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
    const int *__restrict__ seq_lens_kv,
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids_per_batch,
    const int *__restrict__ cum_offsets,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
    const int max_seq_len,
    const int max_dec_len,

@@ -575,8 +574,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
    const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
    const uint32_t kv_b_stride = HEAD_DIM;
    const uint32_t kv_d_stride = BLOCK_SIZE;
    const uint32_t q_start_seq_id =
        batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
    const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
    const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
    const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
                              q_head_idx * HEAD_DIM +

@@ -899,8 +897,8 @@ void MultiQueryAppendC8Attention(
    const paddle::Tensor &seq_lens_q,
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &padding_offsets,
    const paddle::Tensor &cum_offsets,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    const paddle::Tensor &batch_ids,
    const paddle::Tensor &tile_ids_per_batch,

@@ -1054,7 +1052,7 @@ void MultiQueryAppendC8Attention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1111,7 +1109,7 @@ void MultiQueryAppendC8Attention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1146,7 +1144,7 @@ void MultiQueryAppendC8Attention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,

@@ -1181,7 +1179,8 @@ void MultiQueryAppendC8Attention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    padding_offsets.data<int>(),
    batch_id_per_token.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,

@@ -1255,10 +1254,10 @@ void MultiQueryAppendC8Attention(
      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
    }
    const int num_chunks = div_up(max_dec_len, chunk_size);
    const int num_chunks = div_up(max_seq_len, chunk_size);
    dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
    dim3 blocks(32, num_warps);
    if (num_chunks <= 1) {
    if (num_chunks <= 0) {
      auto nosplit_kv_kernel =
          multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
                                                         uint8_t,

@@ -1317,7 +1316,7 @@ void MultiQueryAppendC8Attention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1378,8 +1377,8 @@ void MultiQueryAppendC8Attention(
    reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
    reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(
                        const_cast<T *>(smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1387,7 +1386,7 @@ void MultiQueryAppendC8Attention(
    seq_lens_kv.data<int>(),
    batch_ids.data<int>(),
    tile_ids_per_batch.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    block_table.data<int>(),
    max_seq_len,
    max_dec_len,

@@ -1417,10 +1416,10 @@ void MultiQueryAppendC8Attention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    cum_offsets.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                        smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1437,14 +1436,14 @@ void MultiQueryAppendC8Attention(
    constexpr int blockx = HEAD_DIM / vec_size;
    constexpr int blocky = (128 + blockx - 1) / blockx;
    dim3 grids_merge(min(sm_count * 4, token_num),
                     num_heads);
    dim3 blocks_merge(blockx, blocky);
    merge_multi_chunks_v2_kernel<NV_TYPE,
                                 vec_size,
                                 blocky,
                                 HEAD_DIM,
                                 OUT_NV_TYPE,
                                 ENABLE_PREFILL>
        <<<grids_merge, blocks_merge, 0, stream>>>(
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),

@@ -1452,10 +1451,11 @@ void MultiQueryAppendC8Attention(
    seq_lens_q.data<int>(),
    seq_lens_kv.data<int>(),
    seq_lens_encoder.data<int>(),
    padding_offsets.data<int>(),
    batch_id_per_token.data<int>(),
    cu_seqlens_q.data<int>(),
    shift_bias ? reinterpret_cast<NV_TYPE *>(
                     const_cast<T *>(shift_bias.get().data<T>()))
               : nullptr,
    smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                        smooth_weight.get().data<T>()))
                  : nullptr,

@@ -1499,8 +1499,8 @@ void CascadeAppendAttentionC8Kernel(
    const paddle::Tensor& seq_lens_q,
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,

@@ -1564,8 +1564,8 @@ void CascadeAppendAttentionC8Kernel(
    seq_lens_q,
    seq_lens_kv,
    seq_lens_encoder,
    padding_offsets,
    cum_offsets,
    batch_id_per_token,
    cu_seqlens_q,
    block_table,
    batch_ids,
    tile_ids_per_batch,
@@ -1852,7 +1852,7 @@ __global__ void merge_multi_chunks_kernel(
|
||||
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
|
||||
const int* __restrict__ seq_lens_q,
|
||||
const int* __restrict__ seq_lens_kv,
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ batch_id_per_token,
|
||||
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
T* __restrict__ out,
|
||||
@@ -1866,8 +1866,7 @@ __global__ void merge_multi_chunks_kernel(
|
||||
const int head_dim) {
|
||||
const int vid = threadIdx.x, hid = threadIdx.y;
|
||||
const int qid = blockIdx.x;
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -2111,7 +2110,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT *__restrict__ out,
|
||||
@@ -2127,7 +2126,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int bid = blockIdx.x, hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int seq_len_q = seq_lens_q[bid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[bid];

@@ -2240,7 +2239,8 @@ __global__ void merge_multi_chunks_v2_kernel(
     const int *__restrict__ seq_lens_q,
     const int *__restrict__ seq_lens_kv,
     const int *__restrict__ seq_lens_encoder,
-    const int *__restrict__ padding_offsets,
+    const int *__restrict__ batch_id_per_token,
+    const int *__restrict__ cu_seqlens_q,
     const T *__restrict__ shift_bias,     // [q_num_heads * HEAD_DIM]
     const T *__restrict__ smooth_weight,  // [q_num_heads * HEAD_DIM]
     OutT *__restrict__ out,
@@ -2259,9 +2259,8 @@ __global__ void merge_multi_chunks_v2_kernel(
   __shared__ T smem[bdy * HEAD_DIM];
   __shared__ float md_smem[bdy * 2];
   for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
-    const uint32_t ori_token_id = qid + padding_offsets[qid];
-    const uint32_t bid = ori_token_id / max_seq_len;
-    const uint32_t local_seq_id = ori_token_id % max_seq_len;
+    const uint32_t bid = batch_id_per_token[qid];
+    const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
     const int seq_len_q = seq_lens_q[bid];
     if (seq_len_q == 0) continue;
     int seq_len_kv = seq_lens_kv[bid];

@@ -40,8 +40,8 @@ void CascadeAppendAttentionC16Kernel(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -85,8 +85,8 @@ void CascadeAppendAttentionC8Kernel(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -130,8 +130,8 @@ void CascadeAppendAttentionC4Kernel(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -175,8 +175,8 @@ void CascadeAppendAttentionKernel(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -211,8 +211,8 @@ void CascadeAppendAttentionKernel(
     seq_lens_q,
     seq_lens_kv,
     seq_lens_encoder,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     block_table,
     batch_ids,
     tile_ids_per_batch,

@@ -246,8 +246,8 @@ void CascadeAppendAttentionKernel(
     seq_lens_q,
     seq_lens_kv,
     seq_lens_encoder,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     block_table,
     batch_ids,
     tile_ids_per_batch,

@@ -281,8 +281,8 @@ void CascadeAppendAttentionKernel(
     seq_lens_q,
     seq_lens_kv,
     seq_lens_encoder,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     block_table,
     batch_ids,
     tile_ids_per_batch,

@@ -316,8 +316,8 @@ void CascadeAppendAttentionKernel(
     seq_lens_q,
     seq_lens_kv,
     seq_lens_encoder,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     block_table,
     batch_ids,
     tile_ids_per_batch,

236  custom_ops/gpu_ops/append_attn/decode_attention_func.cuh  (new file)
@@ -0,0 +1,236 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "multi_head_latent_attention_kernel.h"

template <size_t vec_size, typename T>
struct softmax_state_t {
  AlignedVector<T, vec_size> o;
  T m;
  T d;

  __device__ __forceinline__ void init() {
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((half2*)(&o) + i) = make_half2(0, 0);
      }
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
      }
    }
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = __float2half(-5e4f);
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
      m = __float2bfloat16(-3.38953e38f);
    }
  }

  __device__ __forceinline__ softmax_state_t() {
    init();
  }

  __device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
                                        T other_m,
                                        T other_d) {
    // using kType = typename cascade_attn_nv_type2_traits<T>::type;
    T m_prev = m, d_prev = d;
    m = m_prev > other_m ? m_prev : other_m;
    T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);

    d = d_prev * scale1 + other_d * scale2;

#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] = o[i] * scale1 + other_o[i] * scale2;
    }
  }

  __device__ __forceinline__ void normalize() {
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] /= d;
    }
  }
};
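
softmax_state_t implements the standard online-softmax merge: two partial attention results, each carrying a running max m and denominator d, are rescaled to a common max and summed, and normalize() divides by d at the end. A plain-C++ reference of the same recurrence, handy for unit-testing the device code (names and types here are illustrative):

#include <algorithm>
#include <cmath>
#include <vector>

// Reference merge of two partial softmax accumulators (o, m, d). After
// merging all partials, dividing o by d yields the softmax-weighted sum.
struct SoftmaxState {
  std::vector<float> o;  // unnormalized weighted sum
  float m;               // running max of logits
  float d;               // running denominator (sum of exp(logit - m))
};

SoftmaxState Merge(const SoftmaxState& a, const SoftmaxState& b) {
  SoftmaxState r;
  r.m = std::max(a.m, b.m);
  const float s1 = std::exp(a.m - r.m);
  const float s2 = std::exp(b.m - r.m);
  r.d = a.d * s1 + b.d * s2;
  r.o.resize(a.o.size());
  for (size_t i = 0; i < a.o.size(); ++i) {
    r.o[i] = a.o[i] * s1 + b.o[i] * s2;
  }
  return r;
}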

template <size_t vec_size, typename T, uint32_t num_tiles = 0>
struct softmax_state_ts {
  uint32_t num_tiles_ = num_tiles;
  AlignedVector<T, vec_size> o[num_tiles];
  float m;
  float d;

  __device__ __forceinline__ void init() {
#pragma unroll
    for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
      if constexpr (std::is_same<T, half>::value) {
#pragma unroll
        for (int i = 0; i < vec_size / 2; ++i) {
          *((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
        }
      } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
        for (int i = 0; i < vec_size / 2; ++i) {
          *((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
        }
      }
    }
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = -5e4f;
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
      m = -3.38953e38f;
    }
  }

  __device__ __forceinline__ softmax_state_ts() {
    init();
  }

  __device__ __forceinline__ void normalize(const uint32_t tile_id) {
#pragma unroll
    for (size_t i = 0; i < vec_size; i++) {
      o[tile_id][i] /= d;
    }
  }
};

template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
__device__ __forceinline__ void produce_kv(CacheT *smem,
                                           CacheT *kv_base_gptr,
                                           const int *block_table_smem,
                                           const uint32_t seq_offset_gmem,
                                           const uint32_t seq_offset_smem,
                                           const uint32_t kv_head_idx,
                                           const uint32_t kv_num_heads,
                                           const uint32_t tidx,
                                           const uint32_t chunk_start,
                                           const uint32_t chunk_end) {
  int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
  if (block_id < 0) {
    block_id = 0;
  }
  const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
  // 8/16 T/int8 each time
  const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
  const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
  for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
    pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
        smem + smem_offset_base + vid * CACHE_VEC_SIZE,
        kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
        seq_offset_gmem < chunk_end);
  }
}
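
produce_kv stages one KV row from the paged cache into shared memory: the sequence position is first translated through the block table into a physical block id, then into a flat element offset. A scalar sketch of the same address arithmetic, assuming the [max_block_num, kv_num_heads, BLOCK_SIZE, HEAD_DIM_QK] layout noted in the kernel comments:

#include <cstdint>

// Hypothetical scalar version of the paged-KV addressing done in produce_kv.
// Returns the flat element offset of (seq_pos, kv_head) in the paged cache.
uint32_t PagedKvOffset(const int* block_table, uint32_t seq_pos,
                       uint32_t kv_head_idx, uint32_t kv_num_heads,
                       uint32_t block_size, uint32_t head_dim) {
  int block_id = block_table[seq_pos / block_size];  // logical -> physical block
  if (block_id < 0) block_id = 0;  // unmapped: clamp; the real load is predicated off
  const uint32_t block_offset = seq_pos % block_size;
  return ((block_id * kv_num_heads + kv_head_idx) * block_size + block_offset) * head_dim;
}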

template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
                                           const CacheT* k_smem,
                                           const uint32_t kv_idx_base,
                                           const uint32_t stage_idx,
                                           const uint32_t iter_base,
                                           const uint32_t iter_bound,
                                           const uint32_t tidx,
                                           const uint32_t gid,
                                           const float scale,
                                           float *s,
                                           softmax_state_ts<vec_size, T, num_tile_v>& st) {
  const CacheT* smem;
  AlignedVector<T, vec_size> q_vec;
  AlignedVector<T, vec_size> k_vec;
  float m_prev = st.m;
  smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
#pragma unroll
  for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
    if (iter_base + j < iter_bound) {
      s[j] = 0.f;
#pragma unroll
      for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
        Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
        Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
        for (uint32_t i = 0; i < vec_size; ++i) {
          s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
        }
      }
#pragma unroll
      for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
        s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
      }
      __syncthreads();
    } else {
      if constexpr (std::is_same<T, half>::value) {
        s[j] = -5e4f;
      } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        s[j] = -3.38953e38f;
      }
    }
    st.m = st.m > s[j] ? st.m : s[j];
  }

  float o_scale = __expf(m_prev - st.m);
  st.d *= o_scale;

#pragma unroll
  for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
    s[j] = __expf(s[j] - st.m);
    st.d += s[j];
  }
#pragma unroll
  for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
    for (uint32_t i = 0; i < vec_size; ++i) {
      st.o[tile_id][i] *= o_scale;
    }
  }
}

template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
__device__ __forceinline__ void compute_sv(const float *s,
                                           const CacheT *base_v_smem,
                                           const uint32_t stage_idx,
                                           const uint32_t iter_base,
                                           const uint32_t iter_bound,
                                           const uint32_t tidx,
                                           softmax_state_ts<vec_size, T, num_tile>& st) {
  const CacheT* v_smem;
  AlignedVector<T, vec_size> v_vec;
#pragma unroll
  for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
    v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
    for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
      Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
      uint32_t tile_id = vid / bdx;
#pragma unroll
      for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
        st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
      }
    }
  }
}

560  custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu  (new file)
@@ -0,0 +1,560 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "decode_attention_func.cuh"

#define CHECK(call)                                                   \
  do {                                                                \
    const cudaError_t error_code = call;                              \
    if (error_code != cudaSuccess) {                                  \
      printf("CUDA Error:\n");                                       \
      printf("  File:       %s\n", __FILE__);                        \
      printf("  Line:       %d\n", __LINE__);                        \
      printf("  Error code: %d\n", error_code);                      \
      printf("  Error text: %s\n", cudaGetErrorString(error_code));  \
      exit(1);                                                        \
    }                                                                 \
  } while (0)
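
CHECK wraps a CUDA runtime call and aborts with file/line context on failure; the do { ... } while (0) shape keeps it safe inside unbraced if/else branches. Typical usage (illustrative):

// Example: validating an allocation and a kernel launch with CHECK.
float* d_buf = nullptr;
CHECK(cudaMalloc(&d_buf, 1024 * sizeof(float)));
// ... launch work on d_buf ...
CHECK(cudaGetLastError());       // catches launch-configuration errors
CHECK(cudaDeviceSynchronize());  // catches asynchronous execution errors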

template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out,  // [bsz, num_chunks, num_heads, head_dim]
                                                    const T * __restrict__ multi_m,    // [bsz, num_chunks, num_heads]
                                                    const T * __restrict__ multi_d,    // [bsz, num_chunks, num_heads]
                                                    const int * __restrict__ seq_lens_q,
                                                    const int * __restrict__ seq_lens_kv,
                                                    const int * __restrict__ cu_seqlens_q,
                                                    const T * __restrict__ shift_bias,     // [q_num_heads * HEAD_DIM]
                                                    const T * __restrict__ smooth_weight,  // [q_num_heads * HEAD_DIM]
                                                    OutT * __restrict__ out,  // [token_num, num_heads, head_dim]
                                                    const float in_scale,
                                                    const int num_chunks,
                                                    const int chunk_size,
                                                    const int max_seq_len,
                                                    const int num_heads,
                                                    const int head_dim) {
  const int vid = threadIdx.x, ty = threadIdx.y;
  const int qid = blockIdx.x, hid = blockIdx.y;
  const int seq_len_q = seq_lens_q[qid];
  if (seq_len_q == 0) return;
  int seq_len_kv = seq_lens_kv[qid];
  if (seq_len_kv == 0) return;
  seq_len_kv += seq_len_q;
  const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
  if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
    return;
  }
  __shared__ T smem[bdy * HEAD_DIM];
  __shared__ T md_smem[bdy * 2];

  const int start_token_ids = cu_seqlens_q[qid];
  using LoadT = AlignedVector<T, vec_size>;
  LoadT load_vec;
  LoadT res_vec;
  if constexpr (std::is_same<T, half>::value) {
#pragma unroll
    for (int i = 0; i < vec_size / 2; ++i) {
      *((half2*)(&res_vec) + i) = make_half2(0, 0);
    }
  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
#pragma unroll
    for (int i = 0; i < vec_size / 2; ++i) {
      *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
    }
  }
  T m;
  T d = 1.f;
  if constexpr (std::is_same<T, half>::value) {
    m = __float2half(-5e4f);
  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
    m = __float2bfloat16(-3.38953e38f);
  }
  // merge per ty
#pragma unroll 2
  for (int i = ty; i < num_chunks_this_seq; i += bdy) {
    uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
    T m_prev = m;
    T d_prev = d;
    const T m_now = multi_m[offset];
    const T d_now = multi_d[offset];
    m = m_prev > m_now ? m_prev : m_now;
    offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
    Load<T, vec_size>(&multi_out[offset], &load_vec);
    const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
    d = d * scale1 + d_now * scale2;
#pragma unroll
    for (int j = 0; j < vec_size; j++) {
      res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
    }
  }
  // store ty res
  Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
  md_smem[2 * ty] = m;
  md_smem[2 * ty + 1] = d;
  __syncthreads();

  // merge bdy
  softmax_state_t<vec_size, T> st{};
  const uint32_t iter_num = min(num_chunks_this_seq, bdy);
#pragma unroll
  for (int i = 0; i < iter_num; i++) {
    Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
    const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
    st.merge(load_vec, m_tmp, d_tmp);
  }
  st.normalize();

  AlignedVector<OutT, vec_size> out_vec;

#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    out_vec[i] = static_cast<OutT>(st.o[i]);
  }
  Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
}

template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
          uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q,  // [token_num, num_heads, head_dim]
                                                    CacheT * __restrict__ cache_k,  // [max_block_num, num_heads, block_size, head_dim]
                                                    CacheT * __restrict__ cache_v,
                                                    const T * __restrict__ shift_bias,     // [q_num_heads * HEAD_DIM]
                                                    const T * __restrict__ smooth_weight,  // [q_num_heads * HEAD_DIM]
                                                    const int * __restrict__ seq_lens_q,
                                                    const int * __restrict__ seq_lens_kv,
                                                    const int * __restrict__ cu_seqlens_q,
                                                    const int * __restrict__ block_table,  // [bsz, block_num_per_seq]
                                                    const int max_seq_len,
                                                    const int max_dec_len,
                                                    const int max_block_num_per_seq,
                                                    const float scale,
                                                    const float in_scale,
                                                    const uint32_t chunk_size,
                                                    T * __restrict__ tmp_workspace,  // [batch_size, num_chunks, num_heads, head_dim]
                                                    T * __restrict__ tmp_m,          // [batch_size, num_chunks, num_heads]
                                                    T * __restrict__ tmp_d,          // [batch_size, num_chunks, num_heads]
                                                    OutT * __restrict__ out) {
  const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
  const uint32_t bid = bidx, gid = threadIdx.y;
  const uint32_t tidx = threadIdx.x;
  constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
  constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
  constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;

  const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
  const uint32_t kv_num_heads = gridDim.z;
  const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;

  const int *block_table_now = block_table + bid * max_block_num_per_seq;

  const uint32_t num_chunks = gridDim.y;
  const uint32_t chunk_id = blockIdx.y;
  const uint32_t q_len = seq_lens_q[bid];
  if (q_len <= 0) {
    return;
  }
  uint32_t kv_len = seq_lens_kv[bid];  // cached KV length only; the current step is added below
  if (kv_len <= 0) {
    return;
  }
  kv_len += q_len;
  const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
  const uint32_t q_start_idx = cu_seqlens_q[bid];
  const uint32_t q_write_idx = cu_seqlens_q[bid];
  if (chunk_id >= num_chunk_this_seq) {
    return;
  }

  const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
  const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
  const uint32_t chunk_len = chunk_end - chunk_start;

  extern __shared__ uint8_t smem[];
  const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
  T *q_smem = reinterpret_cast<T*>(smem);  // [HEAD_DIM_QK * sizeof(T)]
  T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
#pragma unroll
  for (uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
    ((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
  }
  __syncthreads();
  using VecT = AlignedVector<T, VEC_SIZE>;
  VecT q_vec;
#pragma unroll
  for (uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
    Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
    for (uint32_t i = 0; i < VEC_SIZE; ++i) {
      q_vec[i] *= scale;
    }
    Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
  }

  CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
  uint32_t stage_idx = 0;
  constexpr int loop_times = DEAL_EACH_TIME / bdy;
#pragma unroll
  for (int i = 0; i < NUM_STAGES; ++i) {
#pragma unroll
    for (int j = 0; j < loop_times; ++j) {
      const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
      const uint32_t k_seq_id = chunk_start + k_seq_offset;
      produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
          kv_smem,
          cache_k,
          block_table_now,
          k_seq_id,
          k_seq_offset,
          kv_head_idx,
          kv_num_heads,
          tidx,
          chunk_start,
          chunk_end);
    }
    commit_group();
    stage_idx = (stage_idx + 1) % NUM_STAGES;
  }

  softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
  float s[DEAL_EACH_TIME];

  const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
  for (int iter = 0; iter < num_iters; ++iter) {
    wait_group<NUM_STAGES - 1>();
    __syncthreads();
    // compute qk
    compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
        cu_q_smem,
        kv_smem,
        chunk_start + iter * DEAL_EACH_TIME,
        stage_idx,
        iter * DEAL_EACH_TIME,
        chunk_len,
        tidx,
        gid,
        scale,
        s,
        st);
    __syncthreads();

    // compute sv
    compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
        s,
        kv_smem,
        stage_idx,
        iter * DEAL_EACH_TIME,
        chunk_len,
        tidx,
        st);
    __syncthreads();

#pragma unroll
    for (int j = 0; j < loop_times; ++j) {
      const uint32_t k_seq_offset = j * bdy + gid;
      produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
          kv_smem,
          cache_k,
          block_table_now,
          chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
          stage_idx * DEAL_EACH_TIME + k_seq_offset,
          kv_head_idx,
          kv_num_heads,
          tidx,
          chunk_start,
          chunk_end);
    }
    commit_group();
    stage_idx = (stage_idx + 1) % NUM_STAGES;
  }
  wait_group<0>();
  __syncthreads();

  // normalize if not partition_kv
  for (uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
    const uint32_t tile_id = vid / bdx;
    if (!partition_kv || num_chunk_this_seq == 1) {
      st.normalize(tile_id);
    }
    if (partition_kv && num_chunk_this_seq > 1) {
      const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
      Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
      tmp_m[head_idx] = st.m;
      tmp_d[head_idx] = st.d;
    } else {
      Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
    }
  }
}
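
The kernel above is launched on a (bsz, num_chunks, kv_num_heads) grid: each block processes one KV chunk of one sequence, and when partition_kv is set and a sequence spans several chunks, partial (o, m, d) results go to the temporary workspace for the merge kernel. A small host-side sketch of the chunk bounds each block computes (illustrative):

#include <algorithm>
#include <cstdint>

// Chunk c of a sequence covers [c * chunk_size, min(kv_len, (c+1) * chunk_size)).
// Blocks with chunk_id >= div_up(kv_len, chunk_size) exit immediately.
struct ChunkRange { uint32_t start, end; };

ChunkRange KvChunkBounds(uint32_t kv_len, uint32_t chunk_size, uint32_t chunk_id) {
  const uint32_t start = chunk_id * chunk_size;
  const uint32_t end = std::min(kv_len, start + chunk_size);
  return {start, end};
}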

template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
void MultiQueryDecoderAttention(
    const AppendAttnMetaData& meta_data,
    cudaStream_t &stream,
    const paddle::Tensor &q,
    const paddle::Tensor &cache_k,  // [max_block_num, num_kv_heads, block_size, head_dim]
    const paddle::Tensor &cache_v,  // [num_kv_heads, head_dim]
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    const int max_seq_len,
    const int max_dec_len,
    const float rope_scale,
    const float rope_theta,
    const float softmax_scale,
    const float in_scale,
    paddle::Tensor *out) {
  using NV_TYPE = typename cascade_attn_type_traits<T>::type;

  auto num_heads = meta_data.q_num_heads;
  auto kv_num_heads = meta_data.kv_num_heads;
  auto token_num = meta_data.token_nums;
  auto bsz = meta_data.batch_size;
  auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
  constexpr int num_stages = NUM_STAGE;

  constexpr int vec_size = 16 / sizeof(T);           // 8 16 32
  constexpr int cache_vec_size = 128 / cache_bytes;  // 8 16 32
  constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
  constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
  constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;

  constexpr int blocky = GROUP_SIZE;
  const int gridx = bsz;

  constexpr int num_threads = blockx * blocky;

  auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
                                                            BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
  uint32_t cache_smem_bytes = 0;

  const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
  const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
  cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);

  const uint32_t chunk_size = get_max_partition_size(bsz);
  const int num_chunks = div_up(max_dec_len, chunk_size);
  size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);

  if (smem_size >= 48 * 1024) {
    cudaFuncSetAttribute(
        splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  }
  const int dev_id = 0;
  int sm_count;
  int act_blocks_per_sm;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
  assert(act_blocks_per_sm > 1);

  const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
  const int num_blocks_need = gridx * num_chunks * kv_num_heads;
  const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
  const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);

  dim3 grids(gridx, num_chunks, kv_num_heads);
  dim3 blocks(blockx, blocky);
  if (num_chunks <= 1) {
    auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
                                                                 cache_vec_size, blockx, blocky>;
    if (smem_size >= 48 * 1024) {
      cudaFuncSetAttribute(
          no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
        seq_lens_q.data<int>(),
        seq_lens_kv.data<int>(),
        cu_seqlens_q.data<int>(),
        block_table.data<int>(),
        max_seq_len,
        max_dec_len,
        max_block_num_per_seq,
        softmax_scale,
        in_scale,
        chunk_size,
        nullptr,
        nullptr,
        nullptr,
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())));

    // CHECK(cudaGetLastError());
    // CHECK(cudaDeviceSynchronize());
  } else {
    auto *allocator = paddle::GetAllocator(q.place());
    phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
    tmp_workspace = allocator->Allocate(
        phi::SizeOf(q.dtype()) *
        static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
    tmp_m = allocator->Allocate(
        phi::SizeOf(q.dtype()) *
        static_cast<size_t>(bsz * num_chunks * num_heads));
    tmp_d = allocator->Allocate(
        phi::SizeOf(q.dtype()) *
        static_cast<size_t>(bsz * num_chunks * num_heads));

    splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
        seq_lens_q.data<int>(),
        seq_lens_kv.data<int>(),
        cu_seqlens_q.data<int>(),
        block_table.data<int>(),
        max_seq_len,
        max_dec_len,
        max_block_num_per_seq,
        softmax_scale,
        in_scale,
        chunk_size,
        reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
        reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
        reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())));
    // CHECK(cudaGetLastError());
    // CHECK(cudaDeviceSynchronize());

    constexpr int mblockx = HEAD_DIM_V / vec_size;
    constexpr int bdy = 256 / mblockx;
    dim3 grids_merge(bsz, num_heads);
    dim3 blocks_merge(mblockx, bdy);
    merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
        reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
        reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
        reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
        seq_lens_q.data<int>(),
        seq_lens_kv.data<int>(),
        cu_seqlens_q.data<int>(),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
        reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
        in_scale,
        num_chunks,
        chunk_size,
        max_seq_len,
        num_heads,
        HEAD_DIM_V);
  }
  // CHECK(cudaGetLastError());
  // CHECK(cudaDeviceSynchronize());
}
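
When the KV range is split across chunks, the host path allocates per-(batch, chunk, head) scratch for the partial outputs plus the (m, d) pair before launching the merge kernel. A sketch of the byte counts, mirroring the Allocate calls above; sizeof_dtype stands in for phi::SizeOf(q.dtype()):

#include <cstddef>

// Scratch sizes for split-KV decode, as in the allocator calls above.
struct SplitKvWorkspace {
  size_t out_bytes;  // partial outputs: [bsz, num_chunks, num_heads, HEAD_DIM_V]
  size_t m_bytes;    // running maxima:  [bsz, num_chunks, num_heads]
  size_t d_bytes;    // denominators:    [bsz, num_chunks, num_heads]
};

SplitKvWorkspace ComputeWorkspace(size_t sizeof_dtype, size_t bsz,
                                  size_t num_chunks, size_t num_heads,
                                  size_t head_dim_v) {
  SplitKvWorkspace w;
  w.out_bytes = sizeof_dtype * bsz * num_chunks * num_heads * head_dim_v;
  w.m_bytes = sizeof_dtype * bsz * num_chunks * num_heads;
  w.d_bytes = w.m_bytes;
  return w;
}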

template <typename T>
void DecodeMLAAttentionKernel(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out) {
  const auto token_num = meta_data.token_nums;
  const auto block_size = meta_data.block_size;
  const auto bsz = meta_data.batch_size;
  const auto num_heads = meta_data.q_num_heads;
  const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
  const auto head_dim_qk = meta_data.head_dims;
  const auto head_dim_v = meta_data.head_dims_v;
  const float rope_scale = 0.0;
  const float rope_theta = 0.0;
  const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
  const uint32_t num_stage = get_cascade_attention_num_stages();
  const uint32_t num_threads = get_cascade_attention_num_threads();

  DISPATCH_CAUSAL(causal, CAUSAL,
      {DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
          {DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
              {DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
                  {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
                      {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
                          {MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
                              meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q,
                              block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}

template void DecodeMLAAttentionKernel<paddle::bfloat16>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);

template void DecodeMLAAttentionKernel<paddle::float16>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);
@@ -28,8 +28,8 @@ __global__ void append_decode_cache_T_rope_kernel(
                                   // head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -65,7 +65,7 @@ __global__ void append_decode_cache_T_rope_kernel(
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
     const int h_bias = bias % head_size;
-    const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+    const int start_token_idx = cu_seqlens_q[ori_bi];
     if (seq_lens_encoder[ori_bi] > 0) return;
     const int write_seq_id = seq_lens[ori_bi];
     if (write_seq_id == 0) continue;

@@ -134,8 +134,8 @@ __global__ void append_decode_cache_T_rope_kernel(
                                   // head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -177,7 +177,7 @@ __global__ void append_decode_cache_T_rope_kernel(
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
     const int h_bias = bias % head_size;
-    const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+    const int start_token_idx = cu_seqlens_q[ori_bi];
     if (seq_lens_encoder[ori_bi] > 0) return;
     const int write_seq_id = seq_lens[ori_bi];
     if (write_seq_id == 0) continue;

@@ -254,8 +254,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
                                   // head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -293,7 +293,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
     const int bias = linear_index % half_hidden_size;
     const int hi = bias / half_head_size;  // q + k + v
     const int h_bias = bias % half_head_size;
-    const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+    const int start_token_idx = cu_seqlens_q[ori_bi];
     if (seq_lens_encoder[ori_bi] > 0) return;
     const int write_seq_id = seq_lens[ori_bi];
     if (write_seq_id == 0) continue;

@@ -366,8 +366,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
                                   // head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -409,7 +409,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
     const int bias = linear_index % half_hidden_size;
     const int hi = bias / half_head_size;  // q + k + v
     const int h_bias = bias % half_head_size;
-    const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+    const int start_token_idx = cu_seqlens_q[ori_bi];
     if (seq_lens_encoder[ori_bi] > 0) return;
     const int write_seq_id = seq_lens[ori_bi];
     if (write_seq_id == 0) continue;

@@ -498,8 +498,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
                                   // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -523,7 +523,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
   int q_head_idx, k_head_idx, v_idx;
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
   constexpr int half_head_size = HeadDim / 2;
-  const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+  const int start_token_idx = cu_seqlens_q[bid];
   if (seq_lens_encoder[bid] > 0) return;
   const int write_seq_id = seq_lens[bid];
   if (write_seq_id == 0) return;

@@ -745,8 +745,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
                                   // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -775,7 +775,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
   int q_head_idx, k_head_idx, v_idx;
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
   constexpr int half_head_size = HeadDim / 2;
-  const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+  const int start_token_idx = cu_seqlens_q[bid];
   if (seq_lens_encoder[bid] > 0) return;
   const int write_seq_id = seq_lens[bid];
   if (write_seq_id == 0) return;

@@ -1047,8 +1047,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
                                   // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -1073,7 +1073,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
   int q_head_idx, k_head_idx, v_idx;
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
   constexpr int half_head_size = HeadDim / 2;
-  const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+  const int start_token_idx = cu_seqlens_q[bid];
   if (seq_lens_encoder[bid] > 0) return;
   const int write_seq_id = seq_lens[bid];
   if (write_seq_id == 0) return;

@@ -1346,8 +1346,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
                                   // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -1377,7 +1377,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(

   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
   constexpr int half_head_size = HeadDim / 2;
-  const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+  const int start_token_idx = cu_seqlens_q[bid];
   if (seq_lens_encoder[bid] > 0) return;
   const int write_seq_id = seq_lens[bid];
   if (write_seq_id == 0) return;

@@ -1739,8 +1739,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
                                   // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -1766,7 +1766,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
   constexpr int half_head_size = HeadDim / 2;
   const int half_block_size = block_size / 2;
-  const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+  const int start_token_idx = cu_seqlens_q[bid];
   if (seq_lens_encoder[bid] > 0) return;
   const int write_seq_id = seq_lens[bid];
   if (write_seq_id == 0) return;

@@ -2034,8 +2034,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
                                   // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -2066,7 +2066,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
   constexpr int half_head_size = HeadDim / 2;
   const int half_block_size = block_size / 2;
-  const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+  const int start_token_idx = cu_seqlens_q[bid];
   if (seq_lens_encoder[bid] > 0) return;
   const int write_seq_id = seq_lens[bid];
   if (write_seq_id == 0) return;

@@ -2362,8 +2362,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
                                   // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -2389,7 +2389,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
   constexpr int half_head_size = HeadDim / 2;
   const int half_block_size = block_size / 2;
-  const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+  const int start_token_idx = cu_seqlens_q[bid];
   if (seq_lens_encoder[bid] > 0) return;
   const int write_seq_id = seq_lens[bid];
   if (write_seq_id == 0) return;

@@ -2732,8 +2732,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
                                   // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,

@@ -2764,7 +2764,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
   constexpr int half_head_size = HeadDim / 2;
   const int half_block_size = block_size / 2;
-  const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+  const int start_token_idx = cu_seqlens_q[bid];
   if (seq_lens_encoder[bid] > 0) return;
   const int write_seq_id = seq_lens[bid];
   if (write_seq_id == 0) return;

@@ -21,8 +21,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
     T* value_cache,
     T* qkv_out,
     const int* block_tables,
-    const int* padding_offsets,
-    const int* cum_offsets,
+    const int* batch_id_per_token,
+    const int* cu_seqlens_q,
     const int* seq_lens,
     const int* seq_lens_encoder,
     const float* cos_emb,

@@ -57,8 +57,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -79,8 +79,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -102,8 +102,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -125,8 +125,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -149,8 +149,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
     uint8_t* value_cache,
     T* qkv_out,
     const int* block_tables,
-    const int* padding_offsets,
-    const int* cum_offsets,
+    const int* batch_id_per_token,
+    const int* cu_seqlens_q,
     const int* seq_lens,
     const int* seq_lens_encoder,
     const float* cos_emb,

@@ -182,8 +182,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -207,8 +207,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -232,8 +232,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -257,8 +257,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -282,8 +282,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
     uint8_t* value_cache,
     T* qkv_out,
     const int* block_tables,
-    const int* padding_offsets,
-    const int* cum_offsets,
+    const int* batch_id_per_token,
+    const int* cu_seqlens_q,
     const int* seq_lens,
     const int* seq_lens_encoder,
     const float* cos_emb,

@@ -317,8 +317,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -344,8 +344,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -371,8 +371,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -398,8 +398,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
     value_cache,
     qkv_out,
     block_tables,
-    padding_offsets,
-    cum_offsets,
+    batch_id_per_token,
+    cu_seqlens_q,
     seq_lens,
     seq_lens_encoder,
     cos_emb,

@@ -424,8 +424,8 @@ void DecoderWriteCacheWithRoPEKernel(
     const paddle::Tensor& qkv,
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,

@@ -471,8 +471,8 @@ void DecoderWriteCacheWithRoPEKernel(
     reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
     reinterpret_cast<DataType_*>(qkv_out->data<T>()),
     block_tables.data<int>(),
-    padding_offsets.data<int>(),
-    cum_offsets.data<int>(),
+    batch_id_per_token.data<int>(),
+    cu_seqlens_q.data<int>(),
     seq_lens.data<int>(),
     seq_lens_encoder.data<int>(),
     cos_emb,

@@ -503,8 +503,8 @@ void DecoderWriteCacheWithRoPEKernel(
     value_cache_out->data<uint8_t>(),
     reinterpret_cast<DataType_*>(qkv_out->data<T>()),
     block_tables.data<int>(),
-    padding_offsets.data<int>(),
-    cum_offsets.data<int>(),
+    batch_id_per_token.data<int>(),
+    cu_seqlens_q.data<int>(),
     seq_lens.data<int>(),
     seq_lens_encoder.data<int>(),
     cos_emb,

@@ -536,8 +536,8 @@ void DecoderWriteCacheWithRoPEKernel(
     value_cache_out->data<uint8_t>(),
     reinterpret_cast<DataType_*>(qkv_out->data<T>()),
     block_tables.data<int>(),
-    padding_offsets.data<int>(),
-    cum_offsets.data<int>(),
+    batch_id_per_token.data<int>(),
+    cu_seqlens_q.data<int>(),
     seq_lens.data<int>(),
     seq_lens_encoder.data<int>(),
     cos_emb,

@@ -570,8 +570,8 @@ void DecoderWriteCacheWithRoPEKernel(
     value_cache_out->data<uint8_t>(),
     reinterpret_cast<DataType_*>(qkv_out->data<T>()),
     block_tables.data<int>(),
-    padding_offsets.data<int>(),
-    cum_offsets.data<int>(),
+    batch_id_per_token.data<int>(),
+    cu_seqlens_q.data<int>(),
     seq_lens.data<int>(),
     seq_lens_encoder.data<int>(),
     cos_emb,

@@ -603,8 +603,8 @@ void DecoderWriteCacheWithRoPEKernel(
     value_cache_out->data<uint8_t>(),
     reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
     block_tables.data<int>(),
-    padding_offsets.data<int>(),
-    cum_offsets.data<int>(),
+    batch_id_per_token.data<int>(),
+    cu_seqlens_q.data<int>(),
     seq_lens.data<int>(),
     seq_lens_encoder.data<int>(),
     cos_emb,

@@ -650,8 +650,8 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
     // kv_num_heads, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,

@@ -677,8 +677,8 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
     // kv_num_heads, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,

@@ -703,8 +703,8 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
     // kv_num_heads, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,

@@ -729,8 +729,8 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
     // kv_num_heads, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
@@ -23,8 +23,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -40,4 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
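Editor's note: the recurring substitution in these hunks retires the padded-layout indices (padding_offsets, cum_offsets) in favor of a per-token batch map (batch_id_per_token) and query prefix sums (cu_seqlens_q). A minimal host-side sketch of the assumed semantics follows; the names and layout are illustrative, not FastDeploy API:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> seq_lens_q = {3, 1, 2};        // tokens per batch this step
      std::vector<int> seq_lens_decoder = {7, 0, 4};  // already-decoded lengths

      // cu_seqlens_q is the exclusive prefix sum of seq_lens_q;
      // batch_id_per_token maps each packed token back to its batch.
      std::vector<int> cu_seqlens_q(seq_lens_q.size() + 1, 0);
      std::vector<int> batch_id_per_token;
      for (size_t b = 0; b < seq_lens_q.size(); ++b) {
        cu_seqlens_q[b + 1] = cu_seqlens_q[b] + seq_lens_q[b];
        for (int i = 0; i < seq_lens_q[b]; ++i)
          batch_id_per_token.push_back(static_cast<int>(b));
      }

      // Recover batch id and in-sequence position exactly as the kernels do.
      for (int token_idx = 0; token_idx < cu_seqlens_q.back(); ++token_idx) {
        const int ori_bi = batch_id_per_token[token_idx];
        const int ori_seq_id =
            (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
        std::printf("token %d -> batch %d, pos %d\n", token_idx, ori_bi, ori_seq_id);
      }
      return 0;
    }

This avoids the division by a padded seq_len, so the mapping stays valid for packed variable-length batches.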
@@ -23,7 +23,8 @@ __global__ void VariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales, // [3, num_head, dim_head]
@@ -52,8 +53,7 @@ __global__ void VariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
@@ -61,7 +61,7 @@ __global__ void VariableLengthRotaryKernel(
const int hi = qkv_bias / last_dim;
const int h_bias = qkv_bias % last_dim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias;
@@ -107,7 +107,8 @@ __global__ void VariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
@@ -130,8 +131,7 @@ __global__ void VariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
@@ -139,7 +139,7 @@ __global__ void VariableLengthRotaryKernel(
const int hi = qkv_bias / last_dim;
const int h_bias = qkv_bias % last_dim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t base_idx = token_idx * 3 * hidden_size +
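Editor's note: emb_idx = ori_seq_id * half_lastdim + h_bias / 2 above addresses a cos/sin table laid out per position with dim_head / 2 entries, one per adjacent (even, odd) dimension pair. A reference rotation consistent with that indexing — a sketch under an assumed layout, not the kernel's exact code:

    #include <vector>

    // Rotate one head of length dim_head at sequence position ori_seq_id;
    // cos_emb / sin_emb are assumed flattened as [seq_len, dim_head / 2].
    void rope_apply(float *x, int ori_seq_id, int dim_head,
                    const std::vector<float> &cos_emb,
                    const std::vector<float> &sin_emb) {
      const int half = dim_head / 2;
      for (int k = 0; k < half; ++k) {
        const int emb_idx = ori_seq_id * half + k;
        const float c = cos_emb[emb_idx], s = sin_emb[emb_idx];
        const float x0 = x[2 * k], x1 = x[2 * k + 1];
        x[2 * k] = x0 * c - x1 * s;      // even lane of the pair
        x[2 * k + 1] = x0 * s + x1 * c;  // odd lane of the pair
      }
    }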
@@ -167,7 +167,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales, // [3, num_head, dim_head]
@@ -199,8 +200,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
@@ -208,7 +208,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int hi = qkv_bias / half_lastdim;
const int h_bias = qkv_bias % half_lastdim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int emb_idx = ori_seq_id * last_dim + h_bias;
const int bias_idx_left =
@@ -261,7 +261,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
@@ -285,8 +286,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
@@ -294,7 +294,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int hi = qkv_bias / half_lastdim;
const int h_bias = qkv_bias % half_lastdim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int emb_idx = ori_seq_id * last_dim + h_bias;
const int base_idx_left = token_idx * 3 * full_hidden_size +
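Editor's note: the Neox variants index the table with emb_idx = ori_seq_id * last_dim + h_bias and write separate bias_idx_left / bias_idx_right lanes, i.e. dimension k is paired with k + dim_head / 2 rather than with its neighbor. A hedged reference under the same assumptions as the sketch above:

    #include <vector>

    // Neox-style pairing: lane k rotates with lane k + dim_head / 2.
    void rope_apply_neox(float *x, int ori_seq_id, int dim_head,
                         const std::vector<float> &cos_emb,
                         const std::vector<float> &sin_emb) {
      const int half = dim_head / 2;
      for (int k = 0; k < half; ++k) {
        const int emb_idx = ori_seq_id * dim_head + k;  // note the last_dim stride
        const float c = cos_emb[emb_idx], s = sin_emb[emb_idx];
        const float xl = x[k], xr = x[k + half];
        x[k] = xl * c - xr * s;         // left half
        x[k + half] = xl * s + xr * c;  // right half
      }
    }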
@@ -327,7 +327,8 @@ __global__ void GQAVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales, // [3, q_num_head, dim_head]
@@ -357,14 +358,13 @@ __global__ void GQAVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t bias_idx = hi * last_dim + h_bias;
@@ -410,7 +410,8 @@ __global__ void GQAVariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
@@ -434,14 +435,13 @@ __global__ void GQAVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t base_idx =
@@ -472,7 +472,8 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const float *qkv_out_scales,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const T *qkv_biases,
@@ -504,15 +505,13 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;

-int ori_seq_id;
-ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t bias_idx = hi * last_dim + h_bias;
@@ -561,7 +560,8 @@ template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const T *qkv_biases,
@@ -590,15 +590,13 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;

-int ori_seq_id;
-ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t bias_idx = hi * last_dim + h_bias;
@@ -645,7 +643,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales, // [3, q_num_head, dim_head]
@@ -676,14 +675,13 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / half_lastdim;
const int h_bias = bias % half_lastdim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int emb_idx = ori_seq_id * last_dim + h_bias;
const int bias_idx_left = hi * last_dim + h_bias;
@@ -736,7 +734,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales,
@@ -761,14 +760,13 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / half_lastdim;
const int h_bias = bias % half_lastdim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int emb_idx = ori_seq_id * last_dim + h_bias;
const int base_idx_left =
@@ -805,7 +803,8 @@ __global__ void cache_kernel(
T *__restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size]
const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq]
-const int *__restrict__ padding_offsets, // [num_tokens]
+const int *__restrict__ batch_id_per_token, // [num_tokens]
+const int *__restrict__ cu_seqlens_q, // [bsz]
const int *__restrict__ seq_lens, // [bsz]
const int *__restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,
@@ -831,11 +830,9 @@ __global__ void cache_kernel(
const uint32_t qkv_bias = bias % hidden_size;
const uint32_t hi = qkv_bias / head_size;
const uint32_t h_bias = qkv_bias % head_size;
-const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
-const uint32_t ori_bi = ori_token_idx / max_seq_len;
+const uint32_t ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
-const uint32_t ori_seq_id =
-ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
+const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int32_t *block_table_now = nullptr;

@@ -878,8 +875,8 @@ __global__ void append_write_cache_kv_c8_qkv(
const int *__restrict__ tile_ids,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
-const int *__restrict__ padding_offsets,
-const int *__restrict__ cum_offsets,
+const int *__restrict__ batch_id_per_token,
+const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
@@ -909,15 +906,46 @@ __global__ void append_write_cache_kv_c8_qkv(
const uint32_t end_len = start_len + seq_len_this_time;

const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
+int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;

-const uint32_t start_token_idx =
-batch_id * max_seq_len - cum_offsets[batch_id];
+const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
+if (tile_start >= start_len) {
+constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
+using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
+// int lane_id = wid * 32 + tid;
+// pad zero for this kv_head_idx for this block
+LoadPadKVT pad_cache_vec;
+*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
+// reset k
+constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
+constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
+uint32_t tgt_idx =
+(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
+tid % num_vecs_per_head_k * KV_VEC_SIZE;
+for (int block_i = tid / num_vecs_per_head_k;
+block_i < BLOCK_SIZE;
+block_i += num_token_each_time_k) {
+Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
+&cache_k[tgt_idx + block_i * HEAD_DIM]);
+}
+
+// reset v
+const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
+const int num_token_each_time_v = 32 / num_vecs_per_head_v;
+tgt_idx =
+(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
+tid % num_vecs_per_head_v * KV_VEC_SIZE;
+for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
+block_i += num_token_each_time_v) {
+Store<uint8_t, KV_VEC_SIZE>(
+pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
+}
+}
smem_t k_smem(k_smem_ori);
smem_t v_smem(v_smem_ori);

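Editor's note: the start_token_idx rewrite (batch_id * max_seq_len - cum_offsets[batch_id] becoming cu_seqlens_q[batch_id]) leans on an identity between the two layouts: if cum_offsets[b] counts the padding slots before batch b in the old max_seq_len-padded layout, both expressions address the first packed token of batch b. A quick numeric check under those assumed definitions:

    #include <cassert>
    #include <vector>

    int main() {
      const int max_seq_len = 8;
      std::vector<int> lens = {3, 5, 2};  // real tokens per batch
      std::vector<int> cum_offsets(lens.size());
      std::vector<int> cu_seqlens_q(lens.size() + 1, 0);
      int pad_so_far = 0;
      for (size_t b = 0; b < lens.size(); ++b) {
        cum_offsets[b] = pad_so_far;             // padding before batch b
        pad_so_far += max_seq_len - lens[b];
        cu_seqlens_q[b + 1] = cu_seqlens_q[b] + lens[b];
      }
      for (size_t b = 0; b < lens.size(); ++b)
        assert(static_cast<int>(b) * max_seq_len - cum_offsets[b] == cu_seqlens_q[b]);
      return 0;
    }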
@@ -980,7 +1008,6 @@ __global__ void append_write_cache_kv_c8_qkv(

uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
uint32_t kv_frag[4];
-int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t write_b_stride = HEAD_DIM;
@@ -1118,8 +1145,8 @@ __global__ void append_write_cache_kv_c4_qkv(
const int *__restrict__ tile_ids,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
-const int *__restrict__ padding_offsets,
-const int *__restrict__ cum_offsets,
+const int *__restrict__ batch_id_per_token,
+const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
@@ -1148,10 +1175,46 @@ __global__ void append_write_cache_kv_c4_qkv(
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;

-const uint32_t start_token_idx =
-batch_id * max_seq_len - cum_offsets[batch_id];
+const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
+int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
+
+const uint32_t HEAD_DIM_HALF = HEAD_DIM / 2;
+const uint32_t BLOCK_SIZE_HALF = BLOCK_SIZE / 2;
+
+if (tile_start >= start_len) {
+constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
+using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
+// pad zero for this kv_head_idx for this block
+LoadPadKVT pad_cache_vec;
+*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
+// reset k
+constexpr int num_vecs_per_head_k = HEAD_DIM_HALF / KV_VEC_SIZE; // 4
+constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; // 8
+uint32_t tgt_idx =
+(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM_HALF +
+tid % num_vecs_per_head_k * KV_VEC_SIZE;
+for (int block_i = tid / num_vecs_per_head_k;
+block_i < BLOCK_SIZE;
+block_i += num_token_each_time_k) {
+Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
+&cache_k[tgt_idx + block_i * HEAD_DIM_HALF]);
+}
+
+// reset v
+const int num_vecs_per_head_v = BLOCK_SIZE_HALF / KV_VEC_SIZE; // 2
+const int num_token_each_time_v = 32 / num_vecs_per_head_v; // 16
+tgt_idx =
+(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE_HALF +
+tid % num_vecs_per_head_v * KV_VEC_SIZE;
+for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
+block_i += num_token_each_time_v) {
+Store<uint8_t, KV_VEC_SIZE>(
+pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE_HALF]);
+}
+}

__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T k_scale_smem[HEAD_DIM];
@@ -1262,7 +1325,6 @@ __global__ void append_write_cache_kv_c4_qkv(

uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
uint32_t kv_frag[4];
-int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2;
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
const uint32_t write_b_stride = HEAD_DIM / 2;
@@ -1407,7 +1469,8 @@ void rotary_qk_variable(
const float *qkv_out_scales, // [3, num_head, dim_head]
const T *qkv_bias,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
@@ -1439,7 +1502,8 @@ void rotary_qk_variable(
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1455,7 +1519,8 @@ void rotary_qk_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
@@ -1473,7 +1538,8 @@ void rotary_qk_variable(
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1489,7 +1555,8 @@ void rotary_qk_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
@@ -1508,7 +1575,8 @@ void gqa_rotary_qk_variable(
const float *qkv_out_scales, // [3, num_head, dim_head]
const T *qkv_bias,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
@@ -1543,7 +1611,8 @@ void gqa_rotary_qk_variable(
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1561,7 +1630,8 @@ void gqa_rotary_qk_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
@@ -1581,7 +1651,8 @@ void gqa_rotary_qk_variable(
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1598,7 +1669,8 @@ void gqa_rotary_qk_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1622,7 +1694,8 @@ void gqa_rotary_qk_quant_variable(
const T *cache_k_scales,
const T *cache_v_scales,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
@@ -1654,7 +1727,8 @@ void gqa_rotary_qk_quant_variable(
cos_emb,
sin_emb,
qkv_out_scales,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_bias,
@@ -1673,7 +1747,8 @@ void gqa_rotary_qk_quant_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_bias,
@@ -1699,7 +1774,8 @@ void CascadeAppendWriteCacheKVQKV(
&qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor &block_table,
-const paddle::Tensor &padding_offsets,
+const paddle::Tensor &batch_id_per_token,
+const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const int max_seq_len,
@@ -1725,7 +1801,8 @@ void CascadeAppendWriteCacheKVQKV(
reinterpret_cast<T *>(key_cache_out->data<T>()),
reinterpret_cast<T *>(value_cache_out->data<T>()),
block_table.data<int>(),
-padding_offsets.data<int>(),
+batch_id_per_token.data<int>(),
+cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
@@ -1749,8 +1826,8 @@ void CascadeAppendWriteCacheKVC8QKV(
const paddle::Tensor &cache_v_scale, // [num_kv_heads, head_dim]
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
-const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &batch_id_per_token,
+const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1814,8 +1891,8 @@ void CascadeAppendWriteCacheKVC8QKV(
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
-padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+batch_id_per_token.data<int>(),
+cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
@@ -1837,8 +1914,8 @@ void CascadeAppendWriteCacheKVC4QKV(
const paddle::Tensor &cache_v_zp, // [num_kv_heads, head_dim]
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
-const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &batch_id_per_token,
+const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1884,8 +1961,8 @@ void CascadeAppendWriteCacheKVC4QKV(
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
-padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+batch_id_per_token.data<int>(),
+cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
@@ -25,8 +25,8 @@ void EncoderWriteCacheWithRopeKernel(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
-const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& batch_id_per_token,
+const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,
@@ -63,7 +63,8 @@ void EncoderWriteCacheWithRopeKernel(
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
-padding_offsets.data<int>(),
+batch_id_per_token.data<int>(),
+cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
@@ -82,7 +83,8 @@ void EncoderWriteCacheWithRopeKernel(
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
-padding_offsets.data<int>(),
+batch_id_per_token.data<int>(),
+cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
@@ -103,7 +105,8 @@ void EncoderWriteCacheWithRopeKernel(
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
-padding_offsets.data<int>(),
+batch_id_per_token.data<int>(),
+cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
@@ -123,7 +126,8 @@ void EncoderWriteCacheWithRopeKernel(
CascadeAppendWriteCacheKVQKV<T>(meta_data,
*qkv_out,
block_tables,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
max_seq_len,
@@ -142,8 +146,8 @@ void EncoderWriteCacheWithRopeKernel(
cache_v_scale.get(),
seq_lens_this_time,
seq_lens_decoder,
-padding_offsets,
-cum_offsets,
+batch_id_per_token,
+cu_seqlens_q,
block_tables,
batch_ids,
tile_ids,
@@ -169,8 +173,8 @@ void EncoderWriteCacheWithRopeKernel(
cache_v_zp.get(),
seq_lens_this_time,
seq_lens_decoder,
-padding_offsets,
-cum_offsets,
+batch_id_per_token,
+cu_seqlens_q,
block_tables,
batch_ids,
tile_ids,
@@ -194,23 +194,26 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
-const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets,
-const int encoder_block_shape_q, const int decoder_block_shape_q,
-const int group_size, const int block_size,
-const int decoder_step_token_num) {
+const paddle::Tensor &seq_lens_this_time,
+paddle::Tensor &decoder_batch_ids, // Inplace
+paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
+paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
+paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
+const int encoder_block_shape_q,
+const int decoder_block_shape_q,
+const int group_size,
+const int block_size,
+const int decoder_step_token_num)
+{
auto stream = seq_lens_encoder.stream();
-int bsz = cum_offsets.shape()[0];
-auto max_len_tensor =
-GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
-GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
-max_len_tensor, bsz);
+int bsz = seq_lens_this_time.shape()[0];

// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
// max_just_dec_merged_len_this_time, max_system_len,
// max_just_dec_len_without_system
-auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
-auto max_len_cpu_ptr = max_len_cpu.data<int>();
+paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
+GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
+max_len_tensor_gpu, bsz);
+max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
+
+auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
int max_enc_len_this_time = max_len_cpu_ptr[1];
int max_dec_len_this_time = max_len_cpu_ptr[2];
@@ -222,14 +225,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(

paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
-paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
+paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
-paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
-paddle::Tensor decoder_batch_ids;
-paddle::Tensor decoder_tile_ids_per_batch;
-paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
-paddle::Tensor max_len_kv_cpu; /*cpu*/
+paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
+paddle::Tensor max_len_kv_cpu; /*cpu*/

auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -291,95 +291,64 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
}
-if (max_just_dec_len_this_time > 0) {
-const uint32_t decoder_max_tile_size_per_bs_q =
-div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-
-decoder_batch_ids =
-GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
-paddle::DataType::INT32, seq_lens_encoder.place());
-decoder_tile_ids_per_batch =
-GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
-paddle::DataType::INT32, seq_lens_encoder.place());
+if (max_just_dec_len_this_time > 0) {
+// Clear buffer
+const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
+const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
+PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));

auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
-seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
-decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
-decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
+seq_lens_this_time.data<int>(),
+seq_lens_encoder.data<int>(),
+decoder_batch_ids.data<int>(),
+decoder_tile_ids_per_batch.data<int>(),
+decoder_num_blocks_x.data<int>(),
+bsz,
+decoder_block_shape_q,
group_size);
-decoder_num_blocks_x_cpu =
-decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
-} else {
-decoder_batch_ids =
-GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-decoder_tile_ids_per_batch =
-GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-decoder_num_blocks_x_cpu =
-GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
+decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}

-return {encoder_batch_ids,
-encoder_tile_ids_per_batch,
-encoder_num_blocks_x_cpu, /*cpu*/
-kv_batch_ids,
-kv_tile_ids_per_batch,
-kv_num_blocks_x_cpu, /*cpu*/
-decoder_batch_ids,
-decoder_tile_ids_per_batch,
-decoder_num_blocks_x_cpu, /*cpu*/
-max_len_kv_cpu /*cpu*/,
-max_len_cpu};
-}
-
-std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
-const paddle::DataType &seq_lens_encoder_dtype,
-const paddle::DataType &seq_lens_decoder_dtype,
-const paddle::DataType &seq_lens_this_time_dtype,
-const paddle::DataType &cum_offsets_dtype) {
-return {
-paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-paddle::DataType::INT32, paddle::DataType::INT32};
-}
-
-std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
-const std::vector<int64_t> &seq_lens_encoder_shape,
-const std::vector<int64_t> &seq_lens_decoder_shape,
-const std::vector<int64_t> &seq_lens_this_time_shape,
-const std::vector<int64_t> &cum_offsets_shape) {
-std::vector<int64_t> dynamic_shape = {-1};
-
-return {dynamic_shape,
-dynamic_shape,
-{1},
-dynamic_shape,
-dynamic_shape,
-{1},
-dynamic_shape,
-dynamic_shape,
-{1},
-{1},
-{8}};
+return {
+encoder_batch_ids,
+encoder_tile_ids_per_batch,
+encoder_num_blocks_x_cpu, /*cpu*/
+kv_batch_ids,
+kv_tile_ids_per_batch,
+kv_num_blocks_x_cpu, /*cpu*/
+max_len_kv_cpu, /*cpu*/
+};
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
-.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time",
-"cum_offsets"})
-.Outputs({paddle::Optional("encoder_batch_ids"),
-paddle::Optional("encoder_tile_ids_per_batch"),
-paddle::Optional("encoder_num_blocks"),
-paddle::Optional("kv_batch_ids"),
-paddle::Optional("kv_tile_ids_per_batch"),
-paddle::Optional("kv_num_blocks"),
-paddle::Optional("decoder_batch_ids"),
-paddle::Optional("decoder_tile_ids_per_batch"),
-paddle::Optional("decoder_num_blocks"),
-paddle::Optional("max_len_kv"), "set_max_lengths"})
-.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
-"group_size: int", "block_size: int",
-"decoder_step_token_num: int"})
-.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
-.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
-.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
+.Inputs({
+"seq_lens_encoder",
+"seq_lens_decoder",
+"seq_lens_this_time",
+"decoder_batch_ids",
+"decoder_tile_ids_per_batch",
+"decoder_num_blocks_x_cpu",
+"max_len_tensor_cpu"
+})
+.Outputs({
+paddle::Optional("encoder_batch_ids"),
+paddle::Optional("encoder_tile_ids_per_batch"),
+paddle::Optional("encoder_num_blocks_x_cpu"),
+paddle::Optional("kv_batch_ids"),
+paddle::Optional("kv_tile_ids_per_batch"),
+paddle::Optional("kv_num_blocks_x_cpu"),
+"max_len_kv_cpu"
+})
+.Attrs({
+"encoder_block_shape_q: int",
+"decoder_block_shape_q: int",
+"group_size: int",
+"block_size: int",
+"decoder_step_token_num: int"
+})
+.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
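Editor's note: the decoder buffers that became in-place inputs above are sized from the same tile arithmetic the kernel uses. A minimal sketch of that math (div_up assumed to be ordinary ceiling division):

    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    int main() {
      const uint32_t bsz = 16;
      const uint32_t decoder_step_token_num = 1;  // tokens decoded per step
      const uint32_t group_size = 8;              // q heads per kv head (GQA)
      const uint32_t decoder_block_shape_q = 16;  // q rows per tile

      const uint32_t decoder_max_tile_size_per_bs_q =
          div_up(decoder_step_token_num * group_size, decoder_block_shape_q);
      const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
      // decoder_batch_shape elements are what the cudaMemsetAsync calls zero
      // before split_q_block repopulates decoder_batch_ids and
      // decoder_tile_ids_per_batch.
      std::printf("tiles per batch: %u, buffer elems: %u\n",
                  decoder_max_tile_size_per_bs_q, decoder_batch_shape);
      return 0;
    }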
@@ -16,7 +16,6 @@
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "encoder_write_cache_with_rope_impl.cuh"
#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h"
#include "paddle/phi/backends/context_pool.h"
#include "remote_cache_kv_ipc.h"

@@ -25,7 +24,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
-const int *padding_offsets,
+const int *batch_id_per_token,
+const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int *cu_seqlens_k,
@@ -52,14 +52,13 @@ __global__ void GQAVariableLengthRotarySplitKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
-const int ori_token_idx = token_idx + padding_offsets[token_idx];
-const int ori_bi = ori_token_idx / seq_len;
+const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;

-const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
+const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;

const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
@@ -108,9 +107,10 @@ void gqa_rotary_qk_split_variable(
T *v,
const T *qkv_input,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
-const int *padding_offsets,
+const int *batch_id_per_token,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
+const int *cu_seqlens_q,
const int *cu_seqlens_k,
const int token_num,
const int num_heads,
@@ -133,7 +133,8 @@ void gqa_rotary_qk_split_variable(
qkv_input,
cos_emb,
sin_emb,
-padding_offsets,
+batch_id_per_token,
+cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
@@ -148,13 +149,188 @@ void gqa_rotary_qk_split_variable(
dim_head);
}

+template <typename T,
+typename CacheT,
+uint32_t HEAD_DIM,
+uint32_t BLOCK_SIZE,
+uint32_t NUM_WARPS=4>
+__global__ void append_cache_kv_c16(
+const T *__restrict__ cache_k,
+const T *__restrict__ cache_v,
+T *__restrict__ k_out,
+T *__restrict__ v_out,
+const int *__restrict__ seq_lens_this_time,
+const int *__restrict__ seq_lens_decoder,
+const int *__restrict__ cu_seqlens_k,
+const int *__restrict__ block_tables,
+const int *batch_ids,
+const int *tile_ids_per_batch,
+const int max_blocks_per_seq,
+const int kv_num_heads) {
+// start_kv_idx: start kv_idx current block
+// batch_id:block's batch_id
+// TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8)
+const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z;
+const uint32_t tid = threadIdx.x, wid = threadIdx.y;
+
+const uint32_t batch_id = batch_ids[tile_idx];
+const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
+const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
+if (seq_lens_this_time[batch_id] <= 0) {
+return;
+}
+
+const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq;
+uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE];
+// cache_kv idx
+uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
+uint32_t block_stride = kv_num_heads * kv_h_stride;
+const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride;
+const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride;
+
+// k_out v_out idx
+uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
+T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
+T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
+
+uint32_t kv_frag[4];
+T *frag_dq_T = reinterpret_cast<T *>(kv_frag);
+
+constexpr uint32_t num_vecs_per_head =
+HEAD_DIM / num_elems_per_128b<CacheT>();
+constexpr uint32_t inv_kv_stride = 8 / num_vecs_per_head;
+
+extern __shared__ uint8_t smem[];
+smem_t k_smem(smem);
+uint32_t k_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
+wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp
+
+uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
+wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
+
+uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
+tid % 8 * num_elems_per_128b<CacheT>();
+
+// load k_smem 64 rows 128 cols
+for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
+for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
+k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
+k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
+k_smem_offset_w =
+k_smem.advance_offset_by_column<8, num_vecs_per_head>(k_smem_offset_w, fy);
+k_read_idx += 8 * num_elems_per_128b<CacheT>();
+}
+k_smem_offset_w =
+k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 16;
+k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b<CacheT>();
+}
+commit_group();
+wait_group<0>();
+__syncthreads();
+
+// deal k_smem 64 rows 128 cols
+for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
+uint32_t row_idx = wid * 16 + tid / 4;
+for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
+uint32_t col_idx = fy * 16 + tid % 4 * 2;
+k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
+// layout
+/***
+r0c0,r0c1, r0c8,r0c9
+r8c0,r8c1, r8c8,r8c9
+***/
+T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
+T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;
+
+if (row_idx < end_idx) {
+k_tile_ptr0[0] = frag_dq_T[0];
+k_tile_ptr0[1] = frag_dq_T[1];
+k_tile_ptr0[8] = frag_dq_T[2];
+k_tile_ptr0[9] = frag_dq_T[3];
+}
+
+if (row_idx + 8 < end_idx) {
+k_tile_ptr1[0] = frag_dq_T[4];
+k_tile_ptr1[1] = frag_dq_T[5];
+k_tile_ptr1[8] = frag_dq_T[6];
+k_tile_ptr1[9] = frag_dq_T[7];
+}
+k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>(
+k_smem_offset_r, fy);
+}
+k_smem_offset_r =
+k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16;
+}
+
+// ================v================
+smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT));
+uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
+wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp
+uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
+wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
+
+uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
+tid % 8 * num_elems_per_128b<CacheT>();
+
+// load v_smem 64 rows 128 cols
+for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
+for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
+v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
+v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
+v_smem_offset_w =
+v_smem.advance_offset_by_column<8, num_vecs_per_head>(v_smem_offset_w, fy);
+v_read_idx += 8 * num_elems_per_128b<CacheT>();
+}
+v_smem_offset_w =
+v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 16;
+v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b<CacheT>();
+}
+commit_group();
+wait_group<0>();
+__syncthreads();
+
+// deal v_smem 64 rows 128 cols
+for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
+uint32_t row_idx = wid * 16 + tid / 4;
+for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
+uint32_t col_idx = fy * 16 + tid % 4 * 2;
+v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
+// layout
+/***
+r0c0,r0c1, r0c8,r0c9
+r8c0,r8c1, r8c8,r8c9
+***/
+T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
+T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride;
+
+if (row_idx < end_idx) {
+v_tile_ptr0[0] = frag_dq_T[0];
+v_tile_ptr0[1] = frag_dq_T[1];
+v_tile_ptr0[8] = frag_dq_T[2];
+v_tile_ptr0[9] = frag_dq_T[3];
+}
+
+if (row_idx + 8 < end_idx) {
+v_tile_ptr1[0] = frag_dq_T[4];
+v_tile_ptr1[1] = frag_dq_T[5];
+v_tile_ptr1[8] = frag_dq_T[6];
+v_tile_ptr1[9] = frag_dq_T[7];
+}
+v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>(
+v_smem_offset_r, fy);
+}
+v_smem_offset_r =
+v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_r) - 16;
+}
+}

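Editor's note: append_cache_kv_c16 stages the K tile and then the V tile through one dynamic shared-memory region; v_smem starts BLOCK_SIZE * HEAD_DIM * sizeof(CacheT) bytes in, which implies the launch must reserve twice that. A launch-side sizing sketch (our helper, assuming the back-to-back K/V tiles shown above):

    #include <cstddef>
    #include <cuda_bf16.h>

    // Bytes of dynamic shared memory for one (BLOCK_SIZE x HEAD_DIM) K tile
    // followed by one V tile of the same shape.
    template <typename CacheT>
    constexpr size_t append_cache_kv_c16_smem_bytes(size_t block_size,
                                                    size_t head_dim) {
      return 2 * block_size * head_dim * sizeof(CacheT);
    }

    // e.g. BLOCK_SIZE = 64, HEAD_DIM = 128, bf16 cache -> 32 KiB
    static_assert(append_cache_kv_c16_smem_bytes<__nv_bfloat16>(64, 128) == 32768, "");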
template <typename T,
typename CacheT,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS=4,
bool IS_FP8=false>
-__global__ void append_dequant_cache_kv_c8(
+__global__ void append_cache_kv_c8(
const CacheT *__restrict__ cache_k,
const CacheT *__restrict__ cache_v,
T *__restrict__ k_out,
@@ -169,16 +345,16 @@ __global__ void append_dequant_cache_kv_c8(
const int *tile_ids_per_batch,
const int max_blocks_per_seq,
const int kv_num_heads) {
-// start_kv_idx: 每个block的起始kv_idx
-// batch_id:每个block属于的batch
-// TODO: 1.scale预取 2.frag_dq_T复用 3.流水线编排 4.store访存合并 5.cacheT支持(int8/fp8)
+// start_kv_idx: start kv_idx current block
+// batch_id:block's batch_id
+// TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8)
const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t tid = threadIdx.x, wid = threadIdx.y;

const uint32_t batch_id = batch_ids[tile_idx];
const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
-if (seq_lens_this_time <= 0) {
+if (seq_lens_this_time[batch_id] <= 0) {
return;
}

@@ -192,8 +368,8 @@ __global__ void append_dequant_cache_kv_c8(

// k_out v_out idx
uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
-T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前k block起始指针
-T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前v block起始指针
+T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
+T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;

uint32_t k_frag[4], v_frag[4], frag_dq[4];
T *frag_dq_T = reinterpret_cast<T *>(frag_dq);
@@ -214,13 +390,13 @@ __global__ void append_dequant_cache_kv_c8(

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
tid % 8 * num_elems_per_128b<CacheT>();

-// load k_smem 行是64 列是128
-for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行
-for (int fy = 0; fy < 1; fy++) { // 一次8个128b = 128个uint8
+// load v_smem 64 rows, 128 cols
+for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
+for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -235,13 +411,13 @@ __global__ void append_dequant_cache_kv_c8(
wait_group<0>();
__syncthreads();

-// deal k_smem 行是64 列是128
-for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行
+// deal k_smem 64 rows, 128 cols
+for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
-for (int fy = 0; fy < 4; fy++) { // 1次2个128b(32个uint8),4次循环8个128b(128个uint8)
+for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter
uint32_t col_idx = fy * 32 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
-// 反量化 存储
+// layout
/***
r0c0,r0c1,r0c8,r0c9, r8c0,r8c1,r8c8,r8c9
r0c16,r0c17,r0c24,r0c25, r8c16,r8c17,r8c24,r8c25
@@ -251,8 +427,7 @@ __global__ void append_dequant_cache_kv_c8(
T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;

if (row_idx < end_idx) {
-convert_c8<T,IS_FP8>(frag_dq_T,k_frag[2 * i]); // 4个uint8/fp8 -> 4个T
-
+convert_c8<T,IS_FP8>(frag_dq_T,k_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale;
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale;
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale;
@@ -260,8 +435,7 @@ __global__ void append_dequant_cache_kv_c8(
}

if (row_idx + 8 < end_idx) {
-convert_c8<T,IS_FP8>(frag_dq_T + 4,k_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T
-
+convert_c8<T,IS_FP8>(frag_dq_T + 4,k_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale;
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale;
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale;
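Editor's note: convert_c8 widens four packed 8-bit cache values to T before the per-head scale (cache_k_scale / cache_v_scale) is applied. Its exact quantization convention is not visible in this diff, so the stand-in below assumes a plain recentering around 128 — an illustration, not FastDeploy's implementation:

    #include <cstdint>

    // Hypothetical reference for convert_c8<T, IS_FP8=false>: unpack four
    // uint8 lanes from one 32-bit register into out[0..3]; the caller then
    // multiplies by the dequant scale, as in the kernel above.
    template <typename T>
    void convert_c8_ref(T *out, uint32_t packed) {
      for (int i = 0; i < 4; ++i) {
        const uint8_t q = static_cast<uint8_t>(packed >> (8 * i));
        out[i] = static_cast<T>(static_cast<int>(q) - 128);  // assumed zero point
      }
    }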
@@ -275,8 +449,8 @@ __global__ void append_dequant_cache_kv_c8(
     k_smem_offset_r =
         k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 8;
   }

   // ================v================
   smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT));
   uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
       wid * 8 + tid / 4, tid % 4);  // 4 * 8 per warp
@@ -286,9 +460,9 @@ __global__ void append_dequant_cache_kv_c8(
   uint32_t v_read_idx = (wid * 8 + tid / 4) * BLOCK_SIZE +
                         tid % 4 * num_elems_per_128b<CacheT>();
-  // load v_smem: 128 rows, 64 cols
-  for (int fy = 0; fy < 4; fy++) {    // 8 rows per warp per pass, 32 rows across 4 warps, 4 iterations for 128 rows
-    for (int fz = 0; fz < 1; fz++) {  // one pass moves 4 * 128b = 64 uint8
+  // load v_smem 128 rows 64 cols
+  for (int fy = 0; fy < 4; fy++) {    // 8 rows per warp once, 32 rows all 4 warps once, need 4 iter
+    for (int fz = 0; fz < 1; fz++) {  // 4 * 128b = 64 * uint8 once, need 1 iter
       v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
           v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
       v_smem_offset_w =
@@ -304,42 +478,32 @@ __global__ void append_dequant_cache_kv_c8(
   wait_group<0>();
   __syncthreads();

-  // deal v_smem: 128 rows, 64 cols; row_idx runs over head_dim, col_idx over block_size
-  for (int fy = 0; fy < 2; fy++) {    // 16 rows per warp per pass, 2 passes for 32 rows, 128 rows across 4 warps
+  // deal v_smem 128 rows 64 cols
+  for (int fy = 0; fy < 2; fy++) {    // 16 rows per warp once, 64 rows all 4 warps once, need 2 iter
     uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
-    for (int fz = 0; fz < 2; fz++) {  // one pass moves 2 * 128b (32 uint8), 2 passes for 4 * 128b (64 uint8)
+    for (int fz = 0; fz < 2; fz++) {  // 2 * 128b = 32 * uint8 once, need 2 iter
       uint32_t kv_idx = fz * 32 + tid % 4 * 2;
       v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
-      // dequantize and store
       // layout
       for (int i = 0; i < 4 / 2; i++) {
         T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx;
         T *v_tile_ptr1 = v_tile_ptr0 + 8;
+        convert_c8<T,IS_FP8>(frag_dq_T, v_frag[2 * i]);          // 4 * uint8/fp8 -> 4 * T
+        convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]);  // 4 * uint8/fp8 -> 4 * T
         if (kv_idx < end_idx) {
-          convert_c8<T,IS_FP8>(frag_dq_T, v_frag[2 * i]);  // 4 uint8/fp8 -> 4 T
-#ifdef C8_DEBUG
-          if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
-            printf("1.fy: %d, fz:%d, row_idx: %d, col_idx: %d, v_frag: %.f, %.f, %.f, %.f \n",
-                   fy, fz, kv_idx, dim_idx, static_cast<float>(frag_dq_T[0]), static_cast<float>(frag_dq_T[1]),
-                   static_cast<float>(frag_dq_T[2]), static_cast<float>(frag_dq_T[3]));
-          }
-#endif
           v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale;
-          v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale;
-          v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
-          v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;
-
-          convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]);  // 4 uint8/fp8 -> 4 T
-#ifdef C8_DEBUG
-          if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
-            printf("2.fy: %d, fz:%d, row_idx: %d, col_idx: %d, v_frag: %.f, %.f, %.f, %.f \n",
-                   fy, fz, kv_idx, dim_idx + 8, static_cast<float>(frag_dq_T[4]), static_cast<float>(frag_dq_T[5]),
-                   static_cast<float>(frag_dq_T[6]), static_cast<float>(frag_dq_T[7]));
-          }
-#endif
           v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale;
         }
+        if (kv_idx + 1 < end_idx) {
+          v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale;
+          v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale;
+        }
+        if (kv_idx + 8 < end_idx) {
+          v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
+          v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale;
+        }
+        if (kv_idx + 9 < end_idx) {
+          v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;
+          v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale;
+        }
         kv_idx += 16;
@@ -352,12 +516,250 @@ __global__ void append_dequant_cache_kv_c8(
     }
   }
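For reference, a minimal sketch of the semantics assumed for convert_c8 above (the real helper is defined elsewhere in this extension; the int8 zero point of 128, the fp8 e4m3 format, and the name convert_c8_sketch are assumptions, not taken from this diff):

// Sketch only: unpack 4 packed 8-bit cache values into 4 T values.
// Requires <cuda_fp8.h> for __nv_fp8_e4m3 when IS_FP8 is true.
template <typename T, bool IS_FP8>
__device__ __forceinline__ void convert_c8_sketch(T *out, uint32_t packed) {
  const uint8_t *bytes = reinterpret_cast<const uint8_t *>(&packed);
#pragma unroll
  for (int j = 0; j < 4; ++j) {
    if constexpr (IS_FP8) {
      // reinterpret each byte as an e4m3 float (assumed encoding)
      out[j] = static_cast<T>(static_cast<float>(
          *reinterpret_cast<const __nv_fp8_e4m3 *>(&bytes[j])));
    } else {
      // assumed unsigned-int8 storage with a zero point of 128
      out[j] = static_cast<T>(static_cast<int>(bytes[j]) - 128);
    }
  }
}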
+template <typename T,
+          typename CacheT,
+          uint32_t HEAD_DIM,
+          uint32_t BLOCK_SIZE,
+          uint32_t NUM_WARPS=4>
+__global__ void append_cache_kv_c4(
+    const CacheT *__restrict__ cache_k,
+    const CacheT *__restrict__ cache_v,
+    T *__restrict__ k_out,
+    T *__restrict__ v_out,
+    const T *__restrict__ cache_k_dequant_scales,
+    const T *__restrict__ cache_v_dequant_scales,
+    const T *__restrict__ cache_k_zero_point,
+    const T *__restrict__ cache_v_zero_point,
+    const int *__restrict__ seq_lens_this_time,
+    const int *__restrict__ seq_lens_decoder,
+    const int *__restrict__ cu_seqlens_k,
+    const int *__restrict__ block_tables,
+    const int *batch_ids,
+    const int *tile_ids_per_batch,
+    const int max_blocks_per_seq,
+    const int kv_num_heads) {
+  // start_kv_idx: first kv index handled by the current tile
+  // batch_id: batch that the current tile belongs to
+  // TODO: 1. scale preload 2. frag_dq_T reuse 3. pipeline 4. aligned store 5. CacheT via template (int8/fp8)
+  const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z;
+  const uint32_t tid = threadIdx.x, wid = threadIdx.y;
+
+  const uint32_t batch_id = batch_ids[tile_idx];
+  const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
+  const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
+  if (seq_lens_this_time[batch_id] <= 0) {
+    return;
+  }
+
+  const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq;
+  int block_id = cur_block_table[start_kv_idx / BLOCK_SIZE];  // signed, so the guard below can actually fire
+  if (block_id < 0) block_id = 0;
+
+  constexpr uint32_t HEAD_DIM_HALF = HEAD_DIM / 2;
+  constexpr uint32_t BLOCK_SIZE_HALF = BLOCK_SIZE / 2;
+  // cache_kv idx
+  uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM_HALF;
+  uint32_t block_stride = kv_num_heads * kv_h_stride;
+  const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride;
+  const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride;
+
+  // k_out v_out idx
+  uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
+  T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
+  T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
+
+  extern __shared__ uint8_t smem[];
+
+  uint32_t k_frag[4], v_frag[4], frag_dq[8];
+  T *frag_dq_T = reinterpret_cast<T *>(frag_dq);
+
+  // load dequant scales and zero points
+  const T *cache_k_scale_now = cache_k_dequant_scales + kv_head_idx * HEAD_DIM;
+  const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM;
+  const T *cache_v_scale_now = cache_v_dequant_scales + kv_head_idx * HEAD_DIM;
+  const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM;
+  T *cache_k_scale_smem = reinterpret_cast<T *>(
+      smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT));
+  T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM;
+  T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM;
+  T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM;
+#pragma unroll
+  for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) {
+    cache_k_scale_smem[i] = cache_k_scale_now[i];
+    cache_k_zero_point_smem[i] = cache_k_zp_now[i] + static_cast<T>(136.f);
+    cache_v_scale_smem[i] = cache_v_scale_now[i];
+    cache_v_zero_point_smem[i] = cache_v_zp_now[i] + static_cast<T>(136.f);
+  }
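+  // Note: the +136.f folded into both zero points above presumably absorbs the
+  // constant bias that convert_int4 leaves on each unpacked nibble, so the
+  // inner loops need only one subtract per element; the exact constant follows
+  // from the packed-int4 encoding and is an inference, not documented here.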
+
+  smem_t k_smem(smem);
+  constexpr uint32_t num_vecs_per_head_k =
+      HEAD_DIM_HALF / num_elems_per_128b<CacheT>();  // 2
+  constexpr uint32_t num_vecs_per_blocksize =
+      BLOCK_SIZE_HALF / num_elems_per_128b<CacheT>();
+  constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k;  // 4
+  constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize;
+
+  uint32_t k_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
+      wid * 8 + tid / 4, tid % 4);  // 2(iter) * 4(warp) * 8 rows per warp
+
+  uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
+      wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
+
+  uint32_t k_read_idx = (wid * 8 + tid / 4) * HEAD_DIM / 2 +
+                        tid % 4 * num_elems_per_128b<CacheT>();
+
+  // load k_smem 64 rows 128 cols
+  for (int fz = 0; fz < 2; fz++) {    // 8 rows per warp once, 32 rows all 4 warps once, need 2 iter
+    for (int fy = 0; fy < 1; fy++) {  // 4 * 128b = 128 * int4 once, need 1 iter
+      k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
+          k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
+      k_smem_offset_w =
+          k_smem.advance_offset_by_column<4, num_vecs_per_head_k>(k_smem_offset_w, fy);
+      k_read_idx += 4 * num_elems_per_128b<CacheT>();
+    }
+    k_smem_offset_w =
+        k_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_w) - 4;
+    k_read_idx += 8 * NUM_WARPS * HEAD_DIM / 2 - 4 * num_elems_per_128b<CacheT>();
+  }
+  commit_group();
+  wait_group<0>();
+  __syncthreads();
+
+  // deal k_smem 64 rows 128 cols
+  for (int fz = 0; fz < 1; fz++) {    // 16 rows per warp once, 64 rows all 4 warps once, need 1 iter
+    uint32_t row_idx = wid * 16 + tid / 4;
+    for (int fy = 0; fy < 2; fy++) {  // 2 * 128b = 64 * int4 once, need 2 iter
+      uint32_t col_idx = fy * 64 + tid % 4 * 2;
+      k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
+
+      for (int i = 0; i < 2; i++) {
+        T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
+        T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;
+        convert_int4(frag_dq_T, k_frag[2 * i]);
+        convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]);
+
+        if (row_idx < end_idx) {
+          k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
+          k_tile_ptr0[1] = (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
+          k_tile_ptr0[8] = (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
+          k_tile_ptr0[9] = (frag_dq_T[3] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
+          k_tile_ptr0[16] = (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
+          k_tile_ptr0[17] = (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
+          k_tile_ptr0[24] = (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
+          k_tile_ptr0[25] = (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
+        }
+
+        if (row_idx + 8 < end_idx) {
+          k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
+          k_tile_ptr1[1] = (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
+          k_tile_ptr1[8] = (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
+          k_tile_ptr1[9] = (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
+          k_tile_ptr1[16] = (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
+          k_tile_ptr1[17] = (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
+          k_tile_ptr1[24] = (frag_dq_T[14] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
+          k_tile_ptr1[25] = (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
+        }
+        col_idx += 32;
+      }
+      k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head_k>(
+          k_smem_offset_r, fy);
+    }
+    k_smem_offset_r =
+        k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 4;
+  }
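+  // Note: the 0/1, 8/9, 16/17, 24/25 column offsets above follow the same
+  // m8n8x4 ldmatrix fragment layout sketched for the c8 kernel (each thread
+  // owns two adjacent columns in each 8-column group); the c4 path covers
+  // twice as many columns per fragment because each byte holds two elements.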
+  // ================v================
+  smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT) / 2);
+  uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
+      wid * 16 + tid / 2, tid % 2);  // 16 rows per warp
+
+  uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
+      wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
+
+  uint32_t v_read_idx = (wid * 16 + tid / 2) * BLOCK_SIZE_HALF +
+                        tid % 2 * num_elems_per_128b<CacheT>();
+  // load v_smem 128 rows 64 cols
+  for (int fy = 0; fy < 2; fy++) {    // 16 rows per warp once, 64 rows all 4 warps once, need 2 iter
+    for (int fz = 0; fz < 1; fz++) {  // 2 * 128b = 64 * int4 once, need 1 iter
+      v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
+          v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
+      v_smem_offset_w =
+          v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>(v_smem_offset_w, fz);
+      v_read_idx += 2 * num_elems_per_128b<CacheT>();
+    }
+    v_smem_offset_w =
+        v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_w) - 2;
+    v_read_idx += 16 * NUM_WARPS * BLOCK_SIZE_HALF - 2 * num_elems_per_128b<CacheT>();
+  }
+
+  commit_group();
+  wait_group<0>();
+  __syncthreads();
+
+  // deal v_smem 128 rows 64 cols
+  for (int fy = 0; fy < 2; fy++) {    // 16 rows per warp once, 64 rows all 4 warps once, need 2 iter
+    uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
+    for (int fz = 0; fz < 1; fz++) {  // 2 * 128b = 64 * int4 once, need 1 iter
+      uint32_t kv_idx = fz * 64 + tid % 4 * 2;
+      v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
+      // layout
+      for (int i = 0; i < 2; i++) {
+        T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx;
+        T *v_tile_ptr1 = v_tile_ptr0 + 8;
+
+        convert_int4(frag_dq_T, v_frag[2 * i]);
+        convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]);
+        if (kv_idx < end_idx) {
+          v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
+          v_tile_ptr1[0] = (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
+        }
+        if (kv_idx + 1 < end_idx) {
+          v_tile_ptr0[kv_t_stride] = (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
+          v_tile_ptr1[kv_t_stride] = (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
+        }
+        if (kv_idx + 8 < end_idx) {
+          v_tile_ptr0[8 * kv_t_stride] = (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
+          v_tile_ptr1[8 * kv_t_stride] = (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
+        }
+        if (kv_idx + 9 < end_idx) {
+          v_tile_ptr0[9 * kv_t_stride] = (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
+          v_tile_ptr1[9 * kv_t_stride] = (frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
+        }
+        if (kv_idx + 16 < end_idx) {
+          v_tile_ptr0[16 * kv_t_stride] = (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
+          v_tile_ptr1[16 * kv_t_stride] = (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
+        }
+        if (kv_idx + 17 < end_idx) {
+          v_tile_ptr0[17 * kv_t_stride] = (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
+          v_tile_ptr1[17 * kv_t_stride] = (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
+        }
+        if (kv_idx + 24 < end_idx) {
+          v_tile_ptr0[24 * kv_t_stride] = (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
+          v_tile_ptr1[24 * kv_t_stride] = (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
+        }
+        if (kv_idx + 25 < end_idx) {
+          v_tile_ptr0[25 * kv_t_stride] = (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
+          v_tile_ptr1[25 * kv_t_stride] = (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
+        }
+        kv_idx += 32;
+      }
+      v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>(
+          v_smem_offset_r, fz);
+    }
+    v_smem_offset_r =
+        v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_r) - 2;
+  }
+}
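For reference, a minimal sketch of the semantics assumed for convert_int4 above (the real helper is defined elsewhere; the low-nibble-first order, the absence of sign handling, and the name convert_int4_sketch are assumptions -- the caller applies zero points and scales afterwards):

// Sketch only: expand one 32-bit register of 8 packed 4-bit values into 8 T values.
template <typename T>
__device__ __forceinline__ void convert_int4_sketch(T *out, uint32_t packed) {
#pragma unroll
  for (int j = 0; j < 8; ++j) {
    out[j] = static_cast<T>((packed >> (4 * j)) & 0xF);  // j-th nibble, low first
  }
}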

 template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
-void AppendDequantCache(
+void AppendCacheKV(
     const paddle::Tensor &cache_k,
     const paddle::Tensor &cache_v,
     const paddle::Tensor &cache_k_dequant_scales,
     const paddle::Tensor &cache_v_dequant_scales,
+    const paddle::Tensor &cache_k_zp,
+    const paddle::Tensor &cache_v_zp,
     const paddle::Tensor &seq_lens_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &cu_seqlens_k,
@@ -371,19 +773,41 @@ void AppendDequantCache(
     paddle::Tensor *k_out,
     paddle::Tensor *v_out,
     const cudaStream_t& stream
 ) {
   using NV_TYPE = typename cascade_attn_type_traits<T>::type;
-  if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
-    constexpr int NUM_WARPS = 4;
-    int block_num = cache_num_blocks_x.data<int>()[0];
-    dim3 grids(block_num, 1, kv_num_heads);
-    dim3 blocks(32, NUM_WARPS);
+  constexpr int NUM_WARPS = 4;
+  int block_num = cache_num_blocks_x.data<int>()[0];
+  dim3 grids(block_num, 1, kv_num_heads);
+  dim3 blocks(32, NUM_WARPS);
+  if (cache_quant_type == "none") {
+    const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2;
+    auto kernel_func = append_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;
+
+    if (smem_size >= 48 * 1024) {
+      cudaFuncSetAttribute(kernel_func,
+                           cudaFuncAttributeMaxDynamicSharedMemorySize,
+                           smem_size);
+    }
+    kernel_func<<<grids, blocks, smem_size, stream>>>(
+        reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
+        reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
+        reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
+        reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
+        seq_lens_this_time.data<int>(),
+        seq_lens_decoder.data<int>(),
+        cu_seqlens_k.data<int>(),
+        block_tables.data<int>(),
+        cache_batch_ids.data<int>(),
+        cache_tile_ids_per_batch.data<int>(),
+        max_blocks_per_seq,
+        kv_num_heads
+    );
+  } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
     const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;

-    auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
+    auto kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
     if (cache_quant_type == "cache_fp8") {
-      kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
+      kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
     }
     if (smem_size >= 48 * 1024) {
       cudaFuncSetAttribute(kernel_func,
@@ -406,6 +830,34 @@ void AppendDequantCache(
         max_blocks_per_seq,
         kv_num_heads
     );
+  } else if (cache_quant_type == "cache_int4_zp") {
+    const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) + 4 * HEAD_DIM * sizeof(T);
+
+    auto kernel_func = append_cache_kv_c4<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;
+
+    if (smem_size >= 48 * 1024) {
+      cudaFuncSetAttribute(kernel_func,
+                           cudaFuncAttributeMaxDynamicSharedMemorySize,
+                           smem_size);
+    }
+    kernel_func<<<grids, blocks, smem_size, stream>>>(
+        cache_k.data<uint8_t>(),
+        cache_v.data<uint8_t>(),
+        reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
+        reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
+        reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_dequant_scales.data<T>())),
+        reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_dequant_scales.data<T>())),
+        reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_zp.data<T>())),
+        reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_zp.data<T>())),
+        seq_lens_this_time.data<int>(),
+        seq_lens_decoder.data<int>(),
+        cu_seqlens_k.data<int>(),
+        block_tables.data<int>(),
+        cache_batch_ids.data<int>(),
+        cache_tile_ids_per_batch.data<int>(),
+        max_blocks_per_seq,
+        kv_num_heads
+    );
   } else {
     PADDLE_THROW("%s mode isn't implemented yet", cache_quant_type.c_str());
   }
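+  // Note on the three shared-memory sizes above: the c16 path stages both the
+  // K and V tiles at full width (2 * BLOCK_SIZE * HEAD_DIM elements of T), the
+  // c8 path stages the same tiles at one byte per element, and the c4 path
+  // packs two elements per byte (BLOCK_SIZE * HEAD_DIM bytes for K and V
+  // together) plus 4 * HEAD_DIM values of T for the preloaded K/V dequant
+  // scales and zero points.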
@@ -421,8 +873,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& seq_lens_decoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& kv_batch_ids,
     const paddle::Tensor& kv_tile_ids,
@@ -450,9 +901,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
   const int token_num = qkv_dims[0];
   const int max_blocks_per_seq = block_tables.dims()[1];
   const int block_size = key_cache.dims()[2];
-  const int batch_size = cum_offsets.dims()[0];
+  const int batch_size = seq_lens_this_time.dims()[0];
   const int kv_num_heads = key_cache_dims[1];
-  const int head_dim = key_cache_dims[3];
+  const int head_dim = cache_quant_type == "cache_int4_zp" ? key_cache_dims[3] * 2 : key_cache_dims[3];
   const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
   const float softmax_scale = 1.f / sqrt(head_dim);

@@ -463,7 +914,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
   meta_data.q_num_heads = num_heads;
   meta_data.max_blocks_per_seq = max_blocks_per_seq;
   meta_data.block_size = block_size;
-  meta_data.batch_size = cum_offsets.dims()[0];
+  meta_data.batch_size = seq_lens_this_time.dims()[0];

   phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(qkv.place()));

@@ -493,9 +944,10 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
         v.data<data_t>(),
         qkv.data<data_t>(),
         rotary_embs.data<float>(),
-        padding_offsets.data<int>(),
+        batch_id_per_token.data<int>(),
         seq_lens_encoder.data<int>(),
         seq_lens_decoder.data<int>(),
         cu_seqlens_q.data<int>(),
+        cu_seqlens_k.data<int>(),
         token_num,
         num_heads,
@@ -504,13 +956,38 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
         rotary_embs.dims()[2],
         head_dim,
         stream);

+  if (token_num < kv_token_num) {
+    AppendCacheKV<data_t, 128, 64>(
+        key_cache,
+        value_cache,
+        cache_k_dequant_scales.get(),
+        cache_v_dequant_scales.get(),
+        cache_k_zp.get(),
+        cache_v_zp.get(),
+        seq_lens_this_time,
+        seq_lens_decoder,
+        cu_seqlens_k,
+        block_tables,
+        cache_batch_ids,
+        cache_tile_ids,
+        cache_num_blocks,
+        max_blocks_per_seq,
+        kv_num_heads,
+        cache_quant_type,
+        &k,
+        &v,
+        stream
+    );
+  }
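+  // Note: token_num < kv_token_num indicates that some of the requested KV
+  // tokens already live in the (possibly quantized) block cache rather than in
+  // this step's QKV activations, so AppendCacheKV copies/dequantizes those
+  // cached blocks into the contiguous k/v buffers before attention runs; this
+  // reading is inferred from the call sites, not stated in the diff.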
   // write cache
   if (cache_quant_type == "none") {
     CascadeAppendWriteCacheKVQKV<data_t>(
         meta_data,
         qkv_out,
         block_tables,
-        padding_offsets,
+        batch_id_per_token,
         cu_seqlens_q,
         seq_lens_encoder,
         seq_lens_decoder,
         max_seq_len,
@@ -527,8 +1004,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
         cache_v_quant_scales.get(),
         seq_lens_this_time,
         seq_lens_decoder,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         block_tables,
         kv_batch_ids,
         kv_tile_ids,
@@ -539,6 +1016,32 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
         stream,
         const_cast<paddle::Tensor*>(&key_cache),
         const_cast<paddle::Tensor*>(&value_cache));
+  } else if (cache_quant_type == "cache_int4_zp") {
+    CascadeAppendWriteCacheKVC4QKV<data_t, 128, 64>(
+        meta_data,
+        *const_cast<paddle::Tensor*>(&key_cache),
+        *const_cast<paddle::Tensor*>(&value_cache),
+        qkv_out,
+        cache_k_quant_scales.get(),
+        cache_v_quant_scales.get(),
+        cache_k_zp.get(),
+        cache_v_zp.get(),
+        seq_lens_this_time,
+        seq_lens_decoder,
+        batch_id_per_token,
+        cu_seqlens_q,
+        block_tables,
+        kv_batch_ids,
+        kv_tile_ids,
+        kv_num_blocks_data,
+        max_seq_len,
+        stream,
+        const_cast<paddle::Tensor*>(&key_cache),
+        const_cast<paddle::Tensor*>(&value_cache));
   } else {
     PD_THROW(
         "cache_quant_type_str should be one of [none, cache_int8, cache_fp8, "
         "cache_int4_zp]");
   }
   const char* fmt_write_cache_completed_signal_str = std::getenv("FLAGS_fmt_write_cache_completed_signal");
   const char* FLAGS_use_pd_disaggregation_per_chunk = std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
@@ -559,28 +1062,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
       }
     }
   }

-  if (token_num < kv_token_num) {
-    AppendDequantCache<data_t, 128, 64>(
-        key_cache,
-        value_cache,
-        cache_k_dequant_scales.get(),
-        cache_v_dequant_scales.get(),
-        seq_lens_this_time,
-        seq_lens_decoder,
-        cu_seqlens_k,
-        block_tables,
-        cache_batch_ids,
-        cache_tile_ids,
-        cache_num_blocks,
-        max_blocks_per_seq,
-        kv_num_heads,
-        cache_quant_type,
-        &k,
-        &v,
-        stream
-    );
-  }
   return {q, k, v, qkv_out};
 }

@@ -594,8 +1075,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
          "seq_lens_this_time",
          "seq_lens_encoder",
          "seq_lens_decoder",
-         "padding_offsets",
-         "cum_offsets",
+         "batch_id_per_token",
          "block_tables",
          "kv_batch_ids",
          "kv_tile_ids_per_batch",
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu  (new file, 292 lines)
@@ -0,0 +1,292 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "helper.h"
#include "mla_cache_kernel.cuh"

template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
  auto num_tokens = meta_data.token_nums;
  auto block_size = meta_data.block_size;
  auto nope_size = meta_data.head_dims_v;
  auto all_size = meta_data.head_dims;
  int pe_size = all_size - nope_size;
  auto kv_num_heads = meta_data.kv_num_heads;
  const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;

  constexpr int PackSize = 16 / sizeof(DataType_);
  const int pack_num = elem_nums / PackSize;
  const int blocksize = 128;
  int grid_size = 1;
  GetNumBlocks<128>(pack_num, &grid_size);

  prefill_absorb_cache_kernel<DataType_, PackSize>
      <<<grid_size, blocksize, 0, stream>>>(
          reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
          reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
          reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
          block_tables.data<int>(),
          batch_id_per_token.data<int>(),
          cu_seqlens_q.data<int>(),
          seq_lens.data<int>(),
          seq_lens_decoder.data<int>(),
          max_seq_len,
          max_blocks_per_seq,
          kv_num_heads,
          nope_size,
          pe_size,
          block_size,
          elem_nums);
  return {};
}
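// Note: a worked example of the launch arithmetic above, assuming bfloat16 and
// the sizes quoted in the kernel comments (nope_size = 512, pe_size = 64, so
// all_size = 576): PackSize = 16 / 2 = 8 elements per 128-bit access, and for
// num_tokens = 1024 with kv_num_heads = 1 the kernel covers
// elem_nums = 1024 * 576 = 589824 lanes in pack_num = 73728 vectorized steps
// spread across grid_size blocks of 128 threads.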

std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const std::string& cache_quant_type_str,
    const int max_seq_len) {
  cudaStream_t stream = kv_pe.stream();
  AppendAttnMetaData meta_data;
  const auto& kv_nope_dims = kv_nope.dims();
  const auto& kv_pe_dims = kv_pe.dims();
  const auto& kv_cache_dims = kv_cache.dims();
  meta_data.kv_num_heads = kv_cache_dims[1];
  const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
  meta_data.token_nums = kv_nope_dims[0];
  meta_data.head_dims = kv_cache_dims[3];
  meta_data.head_dims_v = nope_size;

  meta_data.max_blocks_per_seq = block_tables.dims()[1];
  meta_data.block_size = kv_cache_dims[2];
  meta_data.batch_size = seq_lens_decoder.dims()[0];
  switch (kv_pe.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
                                                              kv_nope,
                                                              kv_pe,
                                                              seq_lens,
                                                              seq_lens_decoder,
                                                              batch_id_per_token,
                                                              cu_seqlens_q,
                                                              block_tables,
                                                              max_seq_len,
                                                              stream,
                                                              const_cast<paddle::Tensor*>(&kv_cache));
    }
    case paddle::DataType::FLOAT16: {
      return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
                                                             kv_nope,
                                                             kv_pe,
                                                             seq_lens,
                                                             seq_lens_decoder,
                                                             batch_id_per_token,
                                                             cu_seqlens_q,
                                                             block_tables,
                                                             max_seq_len,
                                                             stream,
                                                             const_cast<paddle::Tensor*>(&kv_cache));
    }
  }
  return {};
}

template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const int max_seq_len,
    const bool speculate_decoder,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
  auto bsz = meta_data.batch_size;
  auto token_num = meta_data.token_nums;
  auto block_size = meta_data.block_size;
  auto nope_size = meta_data.head_dims_v;
  auto all_size = meta_data.head_dims;
  int pe_size = all_size - nope_size;
  auto kv_num_heads = meta_data.kv_num_heads;
  constexpr int PackSize = 16 / sizeof(DataType_);
  const int blocksize = 128;
  int grid_size = 1;

  if (speculate_decoder) {
    const uint32_t elem_nums = token_num * kv_num_heads * all_size;
    const int pack_num = elem_nums / PackSize;
    GetNumBlocks<128>(pack_num, &grid_size);
    speculate_decode_absorb_cache_kernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, stream>>>(
            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
            reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
            block_tables.data<int>(),
            batch_id_per_token.data<int>(),
            cu_seqlens_q.data<int>(),
            seq_lens.data<int>(),
            seq_lens_encoder.data<int>(),
            max_seq_len,
            max_blocks_per_seq,
            kv_num_heads,
            nope_size,
            pe_size,
            block_size,
            elem_nums);
  } else {
    const uint32_t elem_nums = bsz * kv_num_heads * all_size;
    const int pack_num = elem_nums / PackSize;
    GetNumBlocks<128>(pack_num, &grid_size);
    decode_absorb_cache_kernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, stream>>>(
            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
            reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
            block_tables.data<int>(),
            cu_seqlens_q.data<int>(),
            seq_lens.data<int>(),
            seq_lens_encoder.data<int>(),
            max_seq_len,
            max_blocks_per_seq,
            kv_num_heads,
            nope_size,
            pe_size,
            block_size,
            elem_nums);
  }
  return {};
}

std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const std::string& cache_quant_type_str,
    const int max_seq_len,
    const bool speculate_decoder) {
  cudaStream_t stream = kv_pe.stream();
  AppendAttnMetaData meta_data;
  const auto& kv_nope_dims = kv_nope.dims();
  const auto& kv_pe_dims = kv_pe.dims();
  const auto& kv_cache_dims = kv_cache.dims();
  meta_data.kv_num_heads = kv_cache_dims[1];
  const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
  meta_data.token_nums = kv_nope_dims[0];
  meta_data.head_dims = kv_cache_dims[3];
  meta_data.head_dims_v = nope_size;

  meta_data.max_blocks_per_seq = block_tables.dims()[1];
  meta_data.block_size = kv_cache_dims[2];
  meta_data.batch_size = seq_lens_encoder.dims()[0];
  switch (kv_pe.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
                                                             kv_nope,
                                                             kv_pe,
                                                             seq_lens,
                                                             seq_lens_encoder,
                                                             batch_id_per_token,
                                                             cu_seqlens_q,
                                                             block_tables,
                                                             max_seq_len,
                                                             speculate_decoder,
                                                             stream,
                                                             const_cast<paddle::Tensor*>(&kv_cache));
    }
    case paddle::DataType::FLOAT16: {
      return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
                                                            kv_nope,
                                                            kv_pe,
                                                            seq_lens,
                                                            seq_lens_encoder,
                                                            batch_id_per_token,
                                                            cu_seqlens_q,
                                                            block_tables,
                                                            max_seq_len,
                                                            speculate_decoder,
                                                            stream,
                                                            const_cast<paddle::Tensor*>(&kv_cache));
    }
  }
  return {};
}
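// Note: dtypes other than bfloat16/float16 fall through both switch cases
// above (here and in the prefill wrapper) and silently return an empty result;
// callers are presumably expected to have validated the dtype earlier.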

PD_BUILD_STATIC_OP(prefill_mla_write_cache)
    .Inputs({"kv_nope",
             "kv_pe",
             "kv_cache",
             "seq_lens",
             "seq_lens_decoder",
             "batch_id_per_token",
             "cu_seqlens_q",
             "block_tables"})
    .Outputs({"kv_cache_out"})
    .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
    .Attrs({"cache_quant_type_str: std::string",
            "max_seq_len: int"})
    .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));

PD_BUILD_STATIC_OP(decode_mla_write_cache)
    .Inputs({"kv_nope",
             "kv_pe",
             "kv_cache",
             "seq_lens",
             "seq_lens_encoder",
             "batch_id_per_token",
             "cu_seqlens_q",
             "block_tables"})
    .Outputs({"kv_cache_out"})
    .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
    .Attrs({"cache_quant_type_str: std::string",
            "max_seq_len: int",
            "speculate_decoder: bool"})
    .SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh  (new file, 240 lines)
@@ -0,0 +1,240 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"

template <typename T, int VecSize = 1>
__global__ void decode_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [bsz, kv_num_heads, nope_size] 512
    const T* __restrict__ kv_pe,    // [bsz, kv_num_heads, pe_size] 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    // all_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_encoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  constexpr int HalfVecSize = VecSize / 2;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int ori_bi = linear_index / hidden_size;
    const int bias = linear_index % hidden_size;
    const int start_token_idx = cu_seqlens_q[ori_bi];
    if (seq_lens_encoder[ori_bi] > 0) return;
    const int write_seq_id = seq_lens[ori_bi];

    if (write_seq_id == 0) continue;

    const int* block_table_now = nullptr;

    block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const int block_idx = block_table_now[write_seq_id / block_size];
    const int block_offset = write_seq_id % block_size;

    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
                               hi * block_size * all_size +
                               block_offset * all_size + h_bias;
      const uint32_t ori_idx =
          start_token_idx * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
                               hi * block_size * all_size +
                               block_offset * all_size + nope_size + h_bias;
      const uint32_t ori_idx =
          start_token_idx * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}
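// Note: with the running example of nope_size = 512 and pe_size = 64,
// hidden_size = kv_num_heads * 576 per token; a linear index whose in-token
// bias falls below kv_num_heads * 512 copies from kv_nope into the first 512
// lanes of the cache entry, and the remainder copies from kv_pe into lanes
// 512..575 (the `nope_size + h_bias` offset above).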

template <typename T, int VecSize = 1>
__global__ void speculate_decode_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [bsz, kv_num_heads, nope_size] 512
    const T* __restrict__ kv_pe,    // [bsz, kv_num_heads, pe_size] 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    // all_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_encoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  constexpr int HalfVecSize = VecSize / 2;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_id = linear_index / hidden_size;
    const int ori_bi = batch_id_per_token[token_id];
    if (seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % hidden_size;
    const int start_token_idx = cu_seqlens_q[ori_bi];
    const int write_seq_id =
        seq_lens[ori_bi] + token_id - start_token_idx;
    if (write_seq_id == 0) continue;

    const int* block_table_now = nullptr;

    block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const int block_idx = block_table_now[write_seq_id / block_size];
    const int block_offset = write_seq_id % block_size;
    if (block_idx < 0) {
      printf(
          "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
          "%d %d %d %d\n",
          block_idx,
          write_seq_id,
          ori_bi,
          seq_lens[ori_bi],
          token_id,
          cu_seqlens_q[ori_bi]);
    }
    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
                               hi * block_size * all_size +
                               block_offset * all_size + h_bias;
      const uint32_t ori_idx =
          token_id * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
                               hi * block_size * all_size +
                               block_offset * all_size + nope_size + h_bias;
      const uint32_t ori_idx =
          token_id * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}

template <typename T, int VecSize = 1>
__global__ void prefill_absorb_cache_kernel(
    const T* __restrict__ kv_nope,  // [bsz, kv_num_heads, nope_size] 512
    const T* __restrict__ kv_pe,    // [bsz, kv_num_heads, pe_size] 64
    T* __restrict__ kv_cache,       // [num_blocks, kv_num_heads, block_size,
                                    // all_size]
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_decoder,  // [bsz]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const uint32_t all_size = nope_size + pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const uint32_t token_idx = linear_index / hidden_size;
    const uint32_t bias = linear_index % hidden_size;
    const uint32_t ori_bi = batch_id_per_token[token_idx];
    if (seq_lens[ori_bi] == 0) continue;
    const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

    const int* block_table_now = nullptr;
    block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
    const uint32_t block_offset = ori_seq_id % block_size;

    if (bias < nope_hidden_size) {  // nope part
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;
      const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
                               hi * block_size * all_size +
                               block_offset * all_size + h_bias;
      const uint32_t ori_idx =
          token_idx * nope_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    } else {  // pe part
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;
      const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
                               hi * block_size * all_size +
                               block_offset * all_size + nope_size + h_bias;
      const uint32_t ori_idx =
          token_idx * pe_hidden_size + inner_bias;
      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
      Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
    }
  }
}
@@ -0,0 +1,38 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "utils.cuh"

template <typename T>
void DecodeMLAAttentionKernel(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor &q,  // [token_num, num_heads, head_dim]
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor>& attn_mask,
    const paddle::optional<paddle::Tensor>& shift_bias,
    const paddle::optional<paddle::Tensor>& smooth_weight,
    const paddle::Tensor &seq_lens_q,  // q_seq_len is 1
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    int max_seq_len,
    int max_dec_len,
    float softmax_scale,
    float in_scale,
    bool causal,
    cudaStream_t &stream,
    paddle::Tensor *out);
@@ -26,8 +26,8 @@ __global__ void append_clear_cache_int8_block(
     // block_size, head_size // 2]
     const int* __restrict__ seq_lens,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const int max_seq_len,
     const int max_blocks_per_seq,
@@ -41,10 +41,10 @@ __global__ void append_clear_cache_int8_block(
   const int wid = tid / 32;
   const int lane_id = tid % 32;
   const int token_id = blockIdx.x;
-  const int ori_token_id = token_id + padding_offsets[token_id];
-  const int bid = ori_token_id / max_seq_len;
-
-  const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+  const int bid = batch_id_per_token[token_id];
+
+  const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;

   if (seq_lens_encoder[bid] > 0) return;
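// Note: the replacement of `bid * max_seq_len - cum_offsets[bid]` with
// `cu_seqlens_q[bid]` throughout this file computes the same quantity -- the
// index of batch `bid`'s first token in the packed token stream -- but reads
// the prefix sum directly instead of deriving it from padding offsets.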
@@ -100,8 +100,8 @@ __global__ void append_clear_cache_int4_block(
     // block_size, head_size // 2]
     const int* __restrict__ seq_lens,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const int max_seq_len,
     const int max_blocks_per_seq,
@@ -115,10 +115,10 @@ __global__ void append_clear_cache_int4_block(
   const int wid = tid / 32;
   const int lane_id = tid % 32;
   const int token_id = blockIdx.x;
-  const int ori_token_id = token_id + padding_offsets[token_id];
-  const int bid = ori_token_id / max_seq_len;
-
-  const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+  const int bid = batch_id_per_token[token_id];
+
+  const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;

   if (seq_lens_encoder[bid] > 0) return;
@@ -178,8 +178,8 @@ __global__ void append_speculate_cache_rope_kernel(
     // head_size // 2]
     T* __restrict__ q_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_decoder,  // [bsz]
     const float* __restrict__ cos_emb,
     const float* __restrict__ sin_emb,
@@ -214,12 +214,12 @@ __global__ void append_speculate_cache_rope_kernel(
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_id = linear_index / hidden_size;
-    const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
+    const int ori_bi = batch_id_per_token[token_id];
    if (seq_lens_decoder[ori_bi] == 0) continue;
    const int bias = linear_index % hidden_size;
    const int hi = bias / head_size;  // q + k + v
    const int h_bias = bias % head_size;
-    const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+    const int start_token_idx = cu_seqlens_q[ori_bi];
    const int write_seq_id =
        seq_lens_decoder[ori_bi] + token_id - start_token_idx;
    if (write_seq_id == 0) continue;
@@ -235,7 +235,7 @@ __global__ void append_speculate_cache_rope_kernel(
           ori_bi,
           seq_lens_decoder[ori_bi],
           token_id,
-          cum_offsets[ori_bi]);
+          cu_seqlens_q[ori_bi]);
    }
    const int block_offset = write_seq_id % block_size;

@@ -311,8 +311,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     // head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_decoder,  // [bsz]
     const float* __restrict__ cos_emb,
     const float* __restrict__ sin_emb,
@@ -347,12 +347,12 @@ __global__ void append_speculate_cache_neox_rope_kernel(
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_id = linear_index / half_hidden_size;
-    const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
+    const int ori_bi = batch_id_per_token[token_id];
    if (seq_lens_decoder[ori_bi] == 0) continue;
    const int bias = linear_index % half_hidden_size;
    const int hi = bias / half_head_size;  // q + k + v
    const int h_bias = bias % half_head_size;
-    const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+    const int start_token_idx = cu_seqlens_q[ori_bi];
    const int write_seq_id =
        seq_lens_decoder[ori_bi] + token_id - start_token_idx;
    if (write_seq_id == 0) continue;
@@ -368,7 +368,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
           ori_bi,
           seq_lens_decoder[ori_bi],
           token_id,
-          cum_offsets[ori_bi]);
+          cu_seqlens_q[ori_bi]);
    }
    const int block_offset = write_seq_id % block_size;

@@ -458,8 +458,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,
@@ -484,10 +484,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
   const int wid = tid / 32;
   const int lane_id = tid % 32;
   const int token_id = blockIdx.x;
-  const int ori_token_id = token_id + padding_offsets[token_id];
-  const int bid = ori_token_id / max_seq_len;
-
-  const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+  const int bid = batch_id_per_token[token_id];
+
+  const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
   int q_head_idx, k_head_idx, v_idx;
   const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -690,8 +690,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
                                 // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,
@@ -716,10 +716,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
   const int wid = tid / 32;
   const int lane_id = tid % 32;
   const int token_id = blockIdx.x;
-  const int ori_token_id = token_id + padding_offsets[token_id];
-  const int bid = ori_token_id / max_seq_len;
-
-  const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+  const int bid = batch_id_per_token[token_id];
+
+  const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
   int q_head_idx, k_head_idx, v_idx;
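In these quantized variants each CUDA block handles one packed token and each warp one attention head (head_idx = blockIdx.y * NUM_WARPS + wid). A tiny worked example of that index arithmetic, with made-up values chosen purely for illustration:

    #include <cstdio>

    // Assumes blockDim.x = 32 * NUM_WARPS, one warp per head.
    int main() {
      const int NUM_WARPS = 4;  // illustrative value
      const int tid = 97;       // thread index within the block
      const int block_y = 2;    // stand-in for blockIdx.y
      const int wid = tid / 32;      // warp 3
      const int lane_id = tid % 32;  // lane 1
      const int head_idx = block_y * NUM_WARPS + wid;  // head 11
      printf("wid=%d lane_id=%d head_idx=%d\n", wid, lane_id, head_idx);
      return 0;
    }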
@@ -1068,8 +1068,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
                                 // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,
@@ -1097,10 +1097,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   const int lane_id = tid % 32;

   const int token_id = blockIdx.x;
-  const int ori_token_id = token_id + padding_offsets[token_id];
-  const int bid = ori_token_id / max_seq_len;
-
-  const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+  const int bid = batch_id_per_token[token_id];
+
+  const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;

   const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -1130,6 +1130,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   LoadOutScaleT out_scale_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+#pragma unroll
+  for (int v_i = 0; v_i < VecSize; v_i++) {
+    bias_vec[v_i] = 0;
+  }
   const InT* qkv_now = quant_qkv + token_id * hidden_size;
   T* qkv_out_now = qkv_out + token_id * hidden_size;
 #pragma unroll
@@ -1137,8 +1141,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
        head_bias += 32 * VecSize) {
     const int bias_idx = head_idx * HeadDim + head_bias;
     Load<InT, VecSize>(&qkv_now[bias_idx], &src_vec);
-    Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
-    Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
+    // Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
+    // Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
     Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
@@ -1148,10 +1152,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
       // dequant + add_bias + rope
       float input_left = static_cast<float>(src_vec[2 * i]);
       float input_right = static_cast<float>(src_vec[2 * i + 1]);
-      input_left = input_left * out_scale_vec[2 * i] +
-                   static_cast<float>(bias_vec[2 * i]);
-      input_right = input_right * out_scale_vec[2 * i + 1] +
-                    static_cast<float>(bias_vec[2 * i + 1]);
+      // input_left = input_left * out_scale_vec[2 * i] +
+      //              static_cast<float>(bias_vec[2 * i]);
+      // input_right = input_right * out_scale_vec[2 * i + 1] +
+      //               static_cast<float>(bias_vec[2 * i + 1]);
       const float cos_tmp = cos_emb_vec[i];
       const float sin_tmp = sin_emb_vec[i];
       bias_vec[2 * i] =
@@ -1167,6 +1171,35 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
   const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;

+  if (block_offset == 0) {
+    // pad zero for this kv_head_idx for this block
+    LoadPadKVT pad_cache_vec;
+    *(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
+    if (head_idx < num_heads + gqa_group_size) {
+      constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE;
+      constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
+      const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
+                                   block_size * half_head_size +
+                               lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+      for (int block_i = lane_id / num_vecs_per_head_dim;
+           block_i < block_size;
+           block_i += num_token_each_time) {
+        Store<uint8_t, KV_VEC_SIZE>(
+            pad_cache_vec, &key_cache[tgt_idx + block_i * half_head_size]);
+      }
+    } else {
+      const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE;
+      const int num_token_each_time = 32 / num_vecs_per_head_dim;
+      const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
+                                   HeadDim * half_block_size +
+                               lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+      for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
+           block_i += num_token_each_time) {
+        Store<uint8_t, KV_VEC_SIZE>(
+            pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]);
+      }
+    }
+  }
   constexpr int K_VEC_SIZE = 4;
   constexpr int HALF_K_VEC_SIZE = 2;
   using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
@@ -1182,7 +1215,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   LoadScaleT zp_vec1, zp_vec2;
   LoadEmbT cos_emb_vec1, cos_emb_vec2;
   LoadEmbT sin_emb_vec1, sin_emb_vec2;
+#pragma unroll
+  for (int v_i = 0; v_i < HALF_K_VEC_SIZE; v_i++) {
+    bias_vec1[v_i] = 0;
+    bias_vec2[v_i] = 0;
+  }
   const InT* qkv_now = quant_qkv + token_id * hidden_size;
   const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
   //////////
@@ -1191,11 +1228,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
     Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
     /////
-    Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
-    Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
-    Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
-    Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
-                                 &out_scale_vec2);
+    // Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
+    // Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
+    // Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
+    // Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
+    //                              &out_scale_vec2);
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
       Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
@@ -1215,10 +1252,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(

     float input_left = static_cast<float>(src_vec1[0]);
     float input_right = static_cast<float>(src_vec1[1]);
-    input_left =
-        input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
-    input_right =
-        input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
+    // input_left =
+    //     input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
+    // input_right =
+    //     input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
     if (head_idx < num_heads + gqa_group_size) {
       float cos_tmp = cos_emb_vec1[0];
       float sin_tmp = sin_emb_vec1[0];
@@ -1233,10 +1270,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(

     input_left = static_cast<float>(src_vec2[0]);
     input_right = static_cast<float>(src_vec2[1]);
-    input_left =
-        input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
-    input_right =
-        input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
+    // input_left =
+    //     input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
+    // input_right =
+    //     input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
     if (head_idx < num_heads + gqa_group_size) {
       float cos_tmp = cos_emb_vec2[0];
       float sin_tmp = sin_emb_vec2[0];
@@ -1374,8 +1411,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
                                 // block_size, head_size // 2]
     T* __restrict__ qkv_out,
     const int* __restrict__ block_tables,     // [bsz, max_blocks_per_seq]
-    const int* __restrict__ padding_offsets,  // [num_tokens]
-    const int* __restrict__ cum_offsets,
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,          // [bsz]
     const int* __restrict__ seq_lens_encoder,  // [bsz]
     const float* __restrict__ cos_emb,
@@ -1403,10 +1440,10 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
   const int lane_id = tid % 32;

   const int token_id = blockIdx.x;
-  const int ori_token_id = token_id + padding_offsets[token_id];
-  const int bid = ori_token_id / max_seq_len;
-
-  const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+  const int bid = batch_id_per_token[token_id];
+
+  const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;

   const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -1792,4 +1829,4 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
             (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
       }
     }
   }
 }
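The int4 kernels above end by packing two 4-bit quantized values into one byte, as in (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F). A standalone sketch of that packing and the matching unpack; the helper names are ours, not from the repository:

    #include <cassert>
    #include <cstdint>

    // Low nibble holds the first value, high nibble the second.
    static uint8_t pack_int4(uint8_t v1, uint8_t v2) {
      return static_cast<uint8_t>((v2 << 4) | (v1 & 0x0F));
    }

    static void unpack_int4(uint8_t packed, uint8_t& v1, uint8_t& v2) {
      v1 = packed & 0x0F;  // first value back from the low nibble
      v2 = packed >> 4;    // second value back from the high nibble
    }

    int main() {
      uint8_t a = 0, b = 0;
      unpack_int4(pack_int4(5, 12), a, b);
      assert(a == 5 && b == 12);
      return 0;
    }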
@@ -22,8 +22,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
     T* value_cache,
     T* qkv_out,
     const int* block_tables,
-    const int* padding_offsets,
-    const int* cum_offsets,
+    const int* batch_id_per_token,
+    const int* cu_seqlens_q,
     const int* seq_lens,
     const int* seq_lens_encoder,
     const float* cos_emb,
@@ -59,8 +59,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
         value_cache,
         qkv_out,
         block_tables,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         seq_lens,
         cos_emb,
         sin_emb,
@@ -82,8 +82,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
         value_cache,
         qkv_out,
         block_tables,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         seq_lens,
         cos_emb,
         sin_emb,
@@ -106,8 +106,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
     uint8_t* value_cache,
     T* qkv_out,
     const int* block_tables,
-    const int* padding_offsets,
-    const int* cum_offsets,
+    const int* batch_id_per_token,
+    const int* cu_seqlens_q,
     const int* seq_lens,
     const int* seq_lens_encoder,
     const float* cos_emb,
@@ -136,8 +136,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
         value_cache,
         seq_lens,
         block_tables,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         seq_lens_encoder,
         max_seq_len,
         max_blocks_per_seq,
@@ -151,8 +151,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
         value_cache,
         qkv_out,
         block_tables,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         seq_lens,
         seq_lens_encoder,
         cos_emb,
@@ -175,8 +175,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
         value_cache,
         qkv_out,
         block_tables,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         seq_lens,
         seq_lens_encoder,
         cos_emb,
@@ -201,8 +201,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
     uint8_t* value_cache,
     T* qkv_out,
     const int* block_tables,
-    const int* padding_offsets,
-    const int* cum_offsets,
+    const int* batch_id_per_token,
+    const int* cu_seqlens_q,
     const int* seq_lens,
     const int* seq_lens_encoder,
     const float* cos_emb,
@@ -233,8 +233,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
         value_cache,
         seq_lens,
         block_tables,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         seq_lens_encoder,
         max_seq_len,
         max_blocks_per_seq,
@@ -248,8 +248,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
         value_cache,
         qkv_out,
         block_tables,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         seq_lens,
         seq_lens_encoder,
         cos_emb,
@@ -274,8 +274,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
         value_cache,
         qkv_out,
         block_tables,
-        padding_offsets,
-        cum_offsets,
+        batch_id_per_token,
+        cu_seqlens_q,
         seq_lens,
         seq_lens_encoder,
         cos_emb,
@@ -301,8 +301,8 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::Tensor& qkv,
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -349,8 +349,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
         reinterpret_cast<DataType_*>(qkv_out->data<T>()),
         block_tables.data<int>(),
-        padding_offsets.data<int>(),
-        cum_offsets.data<int>(),
+        batch_id_per_token.data<int>(),
+        cu_seqlens_q.data<int>(),
         seq_lens.data<int>(),
         seq_lens_encoder.data<int>(),
         cos_emb,
@@ -376,8 +376,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         value_cache_out->data<uint8_t>(),
         reinterpret_cast<DataType_*>(qkv_out->data<T>()),
         block_tables.data<int>(),
-        padding_offsets.data<int>(),
-        cum_offsets.data<int>(),
+        batch_id_per_token.data<int>(),
+        cu_seqlens_q.data<int>(),
         seq_lens.data<int>(),
         seq_lens_encoder.data<int>(),
         cos_emb,
@@ -409,8 +409,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         value_cache_out->data<uint8_t>(),
         reinterpret_cast<DataType_*>(qkv_out->data<T>()),
         block_tables.data<int>(),
-        padding_offsets.data<int>(),
-        cum_offsets.data<int>(),
+        batch_id_per_token.data<int>(),
+        cu_seqlens_q.data<int>(),
         seq_lens.data<int>(),
         seq_lens_encoder.data<int>(),
         cos_emb,
@@ -442,8 +442,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         value_cache_out->data<uint8_t>(),
         reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
         block_tables.data<int>(),
-        padding_offsets.data<int>(),
-        cum_offsets.data<int>(),
+        batch_id_per_token.data<int>(),
+        cu_seqlens_q.data<int>(),
         seq_lens.data<int>(),
         seq_lens_encoder.data<int>(),
         cos_emb,
@@ -488,8 +488,8 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
                                 // gqa_group_size, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -514,8 +514,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
                                 // gqa_group_size, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -539,8 +539,8 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
                                 // gqa_group_size, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -566,8 +566,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
                                 // gqa_group_size, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -582,4 +582,4 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out);
@@ -23,8 +23,8 @@ void SpeculateWriteCacheWithRoPEKernel(
                                 // gqa_group_size, head_dim] if GQA)
     const paddle::Tensor& seq_lens,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -39,4 +39,4 @@ void SpeculateWriteCacheWithRoPEKernel(
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
     paddle::Tensor* key_cache_out,
-    paddle::Tensor* value_cache_out);
+    paddle::Tensor* value_cache_out);
@@ -37,8 +37,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -37,8 +37,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -38,8 +38,8 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -85,8 +85,8 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -80,8 +80,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -82,8 +82,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -82,8 +82,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -81,8 +81,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -81,8 +81,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
     const paddle::Tensor& seq_lens_q,
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -22,8 +22,8 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
-    const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids,
Some files were not shown because too many files have changed in this diff.