mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 00:06:38 +08:00
Compare commits
120 Commits
fix-gpu-me
...
copilot/ad
Author | SHA1 | Date | |
---|---|---|---|
![]() |
017c82b993 | ||
![]() |
063ec680ff | ||
![]() |
98bfefea02 | ||
![]() |
c60adf4281 | ||
![]() |
bbd548ceb6 | ||
![]() |
f556561584 | ||
![]() |
a553d1896c | ||
![]() |
e31c8f7336 | ||
![]() |
de34222842 | ||
![]() |
8e8a5913da | ||
![]() |
9f0e2a6854 | ||
![]() |
30ddcc9115 | ||
![]() |
2359c8d21c | ||
![]() |
1dc1397ef6 | ||
![]() |
12326b60e1 | ||
![]() |
f12159b630 | ||
![]() |
08b3153661 | ||
![]() |
d00faeec69 | ||
![]() |
7e0bfd024f | ||
![]() |
1f056a7469 | ||
![]() |
319a4bf75f | ||
![]() |
f884cd4f62 | ||
![]() |
f32327661c | ||
![]() |
976aa88e66 | ||
![]() |
ed462cf238 | ||
![]() |
20495f927e | ||
![]() |
0c46318b34 | ||
![]() |
9ead10e1bc | ||
![]() |
571ddc677b | ||
![]() |
316ac546d3 | ||
![]() |
83bd55100b | ||
![]() |
aadd6a94d8 | ||
![]() |
2033450391 | ||
![]() |
ed5133f704 | ||
![]() |
17169a14f2 | ||
![]() |
3d0aaa5923 | ||
![]() |
472402bf4e | ||
![]() |
af49b81ffd | ||
![]() |
b5e20e3015 | ||
![]() |
7833f2f6cb | ||
![]() |
b649494655 | ||
![]() |
7c268693ed | ||
![]() |
e52ce1c4b1 | ||
![]() |
30a1c1783f | ||
![]() |
349aa6348b | ||
![]() |
0c45e225d3 | ||
![]() |
f6f726c773 | ||
![]() |
0d989829bb | ||
![]() |
bd7d15f7ea | ||
![]() |
2cf55168ca | ||
![]() |
41aee08982 | ||
![]() |
b23fc654d9 | ||
![]() |
ab1929f5ff | ||
![]() |
fc3bc56e59 | ||
![]() |
7643e6e6b2 | ||
![]() |
e0e7d68435 | ||
![]() |
4c160aa4dd | ||
![]() |
c7b7126b20 | ||
![]() |
29628de6a7 | ||
![]() |
ed97cf8396 | ||
![]() |
88d44a2c93 | ||
![]() |
f265a26f8b | ||
![]() |
f36a388ffe | ||
![]() |
22c165d6dd | ||
![]() |
e83251699f | ||
![]() |
ac46ef403a | ||
![]() |
0989788b29 | ||
![]() |
6ef3b611b0 | ||
![]() |
460809070c | ||
![]() |
7baf1b56e0 | ||
![]() |
9ec4fa0f8e | ||
![]() |
c870be6d27 | ||
![]() |
3790505319 | ||
![]() |
e24b745d48 | ||
![]() |
aaa2de1afa | ||
![]() |
abde903813 | ||
![]() |
7dbd9412b0 | ||
![]() |
fc598d4c5a | ||
![]() |
31313e0f3d | ||
![]() |
4c998c3636 | ||
![]() |
0a1ce612c2 | ||
![]() |
fa58a9fa8f | ||
![]() |
d22d3de256 | ||
![]() |
2527eb0e4e | ||
![]() |
54b458fd98 | ||
![]() |
d81c57146f | ||
![]() |
2396e49f9e | ||
![]() |
94a61d505c | ||
![]() |
ce998449e0 | ||
![]() |
f7a4bea785 | ||
![]() |
5441538173 | ||
![]() |
2c9b169c0e | ||
![]() |
e0c9a6c76c | ||
![]() |
0fe1d62232 | ||
![]() |
18e5d355a1 | ||
![]() |
8e1b35a09b | ||
![]() |
b6a4115369 | ||
![]() |
693c7d781c | ||
![]() |
aa067a3106 | ||
![]() |
7a521bbf62 | ||
![]() |
f296aff6cf | ||
![]() |
205b706ef8 | ||
![]() |
306c024ff3 | ||
![]() |
905d89e42f | ||
![]() |
1908465542 | ||
![]() |
0e4df5a6f4 | ||
![]() |
bf0cf5167a | ||
![]() |
7e751c93ae | ||
![]() |
27f2e7a6f1 | ||
![]() |
6ac7cea81b | ||
![]() |
adc246127b | ||
![]() |
6dd61a1bab | ||
![]() |
253f388372 | ||
![]() |
d6369b4d51 | ||
![]() |
0513a78ecc | ||
![]() |
0297127a93 | ||
![]() |
2bd7d90929 | ||
![]() |
6566e29807 | ||
![]() |
085fe070f2 | ||
![]() |
927e8ec55e |
4
.github/workflows/_accuracy_test.yml
vendored
4
.github/workflows/_accuracy_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -143,7 +143,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
|
4
.github/workflows/_base_test.yml
vendored
4
.github/workflows/_base_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -143,7 +143,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
|
3
.github/workflows/_build_linux.yml
vendored
3
.github/workflows/_build_linux.yml
vendored
@@ -134,7 +134,6 @@ jobs:
|
||||
fi
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
chown -R $(whoami) /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
|
||||
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
|
||||
@@ -149,7 +148,7 @@ jobs:
|
||||
elif [[ "${PADDLEVERSION}" != "" ]];then
|
||||
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
else
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
fi
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
73
.github/workflows/_ci_image_build.yml
vendored
Normal file
73
.github/workflows/_ci_image_build.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
name: Docker Build
|
||||
description: "FastDeploy CI Image Build"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
CI_DOCKER_IMAGE_NAME:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
DOCKER_IMAGE_NAME:
|
||||
description: "Build Images"
|
||||
required: false
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
||||
outputs:
|
||||
docker_name_precheck:
|
||||
description: "Output path of the generated wheel"
|
||||
value: ${{ jobs.docker_build.outputs.docker_name_precheck }}
|
||||
|
||||
jobs:
|
||||
docker_build:
|
||||
runs-on: [self-hosted, Docker-Build]
|
||||
outputs:
|
||||
docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
id: docker_build
|
||||
shell: bash
|
||||
env:
|
||||
docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }}
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE_NAME }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
# Docker Build
|
||||
cd tools/dockerfile/
|
||||
set -e
|
||||
cp ../../requirements.txt ./
|
||||
cp ../../scripts/unittest_requirement.txt ./
|
||||
docker build -t ${docker_image_name} -f Dockerfile.ci . \
|
||||
--network host \
|
||||
--no-cache
|
||||
docker push ${docker_image_name}
|
||||
echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT
|
4
.github/workflows/_logprob_test_linux.yml
vendored
4
.github/workflows/_logprob_test_linux.yml
vendored
@@ -39,6 +39,7 @@ jobs:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
|
||||
run: |
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -116,7 +117,6 @@ jobs:
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
@@ -133,7 +133,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
|
4
.github/workflows/_pre_ce_test.yml
vendored
4
.github/workflows/_pre_ce_test.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
python -m pip install ${fd_wheel_url}
|
||||
bash scripts/run_pre_ce.sh
|
||||
'
|
||||
|
4
.github/workflows/_stable_test.yml
vendored
4
.github/workflows/_stable_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
|
9
.github/workflows/_unit_test_coverage.yml
vendored
9
.github/workflows/_unit_test_coverage.yml
vendored
@@ -60,7 +60,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -168,13 +168,10 @@ jobs:
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install coverage
|
||||
python -m pip install diff-cover
|
||||
python -m pip install pytest-cov
|
||||
python -m pip install jsonschema aistudio_sdk==0.3.5
|
||||
python -m pip install -r scripts/unittest_requirement.txt
|
||||
python -m pip install ${fd_wheel_url}
|
||||
rm -rf fastdeploy
|
||||
# coverage subprocess use
|
||||
|
18
.github/workflows/ce_job.yml
vendored
18
.github/workflows/ce_job.yml
vendored
@@ -199,13 +199,13 @@ jobs:
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
|
||||
ce_upload_sm8689:
|
||||
@@ -224,9 +224,9 @@ jobs:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
run: |
|
||||
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
|
||||
echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
|
||||
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
@@ -238,11 +238,11 @@ jobs:
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
|
174
.github/workflows/ci_image_update.yml
vendored
Normal file
174
.github/workflows/ci_image_update.yml
vendored
Normal file
@@ -0,0 +1,174 @@
|
||||
name: CI Images Build
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
|
||||
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
jobs:
|
||||
clone:
|
||||
environment: CodeSync
|
||||
name: FD-Clone-Linux
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Python Setup
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
commit_id=${{ github.sha }}
|
||||
tag_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
|
||||
else
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
||||
|
||||
resultshow:
|
||||
name: Show Code Archive Output
|
||||
needs: clone
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Print wheel path
|
||||
run: |
|
||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
||||
|
||||
ci_image_build:
|
||||
name: CI Images Build
|
||||
needs: clone
|
||||
uses: ./.github/workflows/_ci_image_build.yml
|
||||
with:
|
||||
CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
|
||||
|
||||
build_sm8090:
|
||||
name: BUILD_SM8090
|
||||
needs: [clone, ci_image_build]
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "90"
|
||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
|
||||
unittest_coverage:
|
||||
name: Run FastDeploy Unit Tests and Coverage
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
logprob_test:
|
||||
name: Run FastDeploy LogProb Tests
|
||||
needs: [build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
pre_ce_test:
|
||||
name: Extracted partial CE model tasks to run in CI.
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_pre_ce_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
base_test:
|
||||
name: Run Base Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_base_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
accuracy_test:
|
||||
name: Run Accuracy Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_accuracy_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
|
||||
publish_pre_check:
|
||||
name: Publish Docker Images Pre Check
|
||||
needs: [ci_image_build, unittest_coverage,logprob_test,pre_ce_test,base_test,accuracy_test,stable_test]
|
||||
runs-on: [self-hosted, Docker-Build]
|
||||
steps:
|
||||
- name: Images Uploading
|
||||
env:
|
||||
images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
||||
run: |
|
||||
echo "images_name=${images_name}"
|
||||
docker images ${ci_image_name}
|
||||
docker tag ${images_name} ${ci_image_name}
|
||||
docker push ${ci_image_name}
|
2
.github/workflows/pr_build_and_test.yml
vendored
2
.github/workflows/pr_build_and_test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "89,90"
|
||||
COMPILE_ARCH: "90"
|
||||
WITH_NIGHTLY_BUILD: "OFF"
|
||||
FD_VERSION: "0.0.0"
|
||||
|
||||
|
10
.github/workflows/publish_job.yml
vendored
10
.github/workflows/publish_job.yml
vendored
@@ -319,3 +319,13 @@ jobs:
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
9
.gitmodules
vendored
9
.gitmodules
vendored
@@ -1,9 +0,0 @@
|
||||
[submodule "custom_ops/third_party/DeepGEMM"]
|
||||
path = custom_ops/third_party/DeepGEMM
|
||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||
[submodule "custom_ops/third_party/cutlass"]
|
||||
path = custom_ops/third_party/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
[submodule "custom_ops/third_party/nlohmann_json"]
|
||||
path = custom_ops/third_party/nlohmann_json
|
||||
url = https://github.com/nlohmann/json.git
|
@@ -59,7 +59,7 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
|
||||
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
|
||||
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
|
||||
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
|
||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md.md)
|
||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
|
||||
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
|
||||
|
||||
|
@@ -57,7 +57,7 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
|
||||
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
|
||||
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
|
||||
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
|
||||
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md.md)
|
||||
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md)
|
||||
|
||||
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!
|
||||
|
||||
|
@@ -14,7 +14,7 @@
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
void set_value_by_flag_and_id(const bool *stop_flags,
|
||||
void set_value_by_flags_and_idx(const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
@@ -50,7 +50,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
int length = pre_ids_all_shape[1];
|
||||
int length_input_ids = input_ids.shape()[1];
|
||||
|
||||
set_value_by_flag_and_id(stop_flags.data<bool>(),
|
||||
set_value_by_flags_and_idx(stop_flags.data<bool>(),
|
||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||
input_ids.data<int64_t>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
|
@@ -46,7 +46,7 @@ void update_inputs_kernel(bool *not_need_stop,
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
|
||||
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
void UpdateInputs(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu)
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"input_ids", "input_ids_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputes));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputs));
|
||||
|
@@ -140,8 +140,8 @@ void AppendAttentionKernel(
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_mask,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
|
||||
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
@@ -273,11 +273,15 @@ void AppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
@@ -296,11 +300,15 @@ void AppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -309,7 +317,6 @@ void AppendAttentionKernel(
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
@@ -336,7 +343,6 @@ void AppendAttentionKernel(
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
|
@@ -52,6 +52,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -74,6 +75,11 @@ __global__ void multi_query_append_attention_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -142,7 +148,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -422,6 +428,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -445,6 +452,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -511,7 +523,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -902,6 +914,7 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -960,6 +973,7 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1134,6 +1148,7 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1206,6 +1221,7 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
|
@@ -57,6 +57,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -85,6 +86,11 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -173,7 +179,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -520,6 +526,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -549,6 +556,11 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -635,7 +647,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -1107,6 +1119,7 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1171,6 +1184,7 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1365,6 +1379,7 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1445,6 +1460,7 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
|
@@ -32,14 +32,15 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8=false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__global__ void multi_query_append_attention_c8_kernel(
|
||||
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
|
||||
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
// head_dim]
|
||||
CacheT *__restrict__ cache_v,
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads]
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int *__restrict__ seq_lens,
|
||||
@@ -57,6 +58,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -86,33 +88,40 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
T cache_k_scale_reg[num_frags_y * 4];
|
||||
T cache_v_scale_reg[num_frags_y * 2];
|
||||
if (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
|
||||
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
|
||||
const uint32_t q_end =
|
||||
@@ -180,7 +189,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -201,6 +210,13 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale = nullptr;
|
||||
T* v_smem_scale = nullptr;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale = k_smem_scale + num_frags_z * 16;
|
||||
}
|
||||
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
@@ -282,10 +298,22 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale,
|
||||
cache_k_scale_reg,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
&qo_smem,
|
||||
&q_smem_offset_r,
|
||||
&k_smem,
|
||||
@@ -318,6 +346,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
s_frag, o_frag, m_frag, d_frag);
|
||||
__syncthreads();
|
||||
|
||||
const int ori_kv_idx_base = kv_idx_base;
|
||||
kv_idx_base += num_frags_z * 16;
|
||||
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -336,6 +365,18 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
commit_group();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale,
|
||||
cache_v_scale_reg,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
ori_kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
@@ -346,7 +387,9 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
BLOCK_SIZE,
|
||||
T,
|
||||
CacheT,
|
||||
is_scale_channel_wise, IsFP8>(
|
||||
is_scale_channel_wise,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
|
||||
__syncthreads();
|
||||
|
||||
@@ -463,14 +506,15 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool is_scale_channel_wise=false,
|
||||
bool IsFP8=false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
|
||||
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
// head_dim]
|
||||
CacheT *__restrict__ cache_v,
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int *__restrict__ seq_lens,
|
||||
@@ -489,6 +533,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -518,32 +563,39 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
T cache_k_scale_reg[num_frags_y * 4];
|
||||
T cache_v_scale_reg[num_frags_y * 2];
|
||||
if (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
|
||||
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
const uint32_t q_end =
|
||||
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
|
||||
@@ -609,7 +661,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -634,6 +686,13 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale = nullptr;
|
||||
T* v_smem_scale = nullptr;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
|
||||
}
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
CAUSAL
|
||||
@@ -716,11 +775,23 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
commit_group();
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale,
|
||||
cache_k_scale_reg,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
&qo_smem,
|
||||
&q_smem_offset_r,
|
||||
&k_smem,
|
||||
@@ -753,6 +824,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
s_frag, o_frag, m_frag, d_frag);
|
||||
__syncthreads();
|
||||
|
||||
const uint32_t ori_kv_idx_base = kv_idx_base;
|
||||
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
|
||||
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -771,6 +843,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
commit_group();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale,
|
||||
cache_v_scale_reg,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
ori_kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
@@ -781,7 +865,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
BLOCK_SIZE,
|
||||
T,
|
||||
CacheT,
|
||||
is_scale_channel_wise, IsFP8>(
|
||||
is_scale_channel_wise,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
|
||||
__syncthreads();
|
||||
|
||||
@@ -895,7 +981,8 @@ template <typename T,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool IsFP8=false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
void MultiQueryAppendC8Attention(
|
||||
const AppendAttnMetaData &meta_data,
|
||||
const paddle::Tensor &qkv,
|
||||
@@ -953,7 +1040,8 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
|
||||
constexpr uint32_t smem_size =
|
||||
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
|
||||
num_frags_z * 16 * sizeof(T) * 2;
|
||||
auto split_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -970,7 +1058,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false, IsFP8>;
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
if (is_scale_channel_wise) {
|
||||
split_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
@@ -988,7 +1078,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true, IsFP8>;
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(split_kv_kernel,
|
||||
@@ -1022,7 +1114,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false, IsFP8>;
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
if (is_scale_channel_wise) {
|
||||
nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
@@ -1040,7 +1134,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true, IsFP8>;
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(nosplit_kv_kernel,
|
||||
@@ -1075,6 +1171,7 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1133,6 +1230,7 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1218,7 +1316,8 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
|
||||
constexpr uint32_t smem_size =
|
||||
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
|
||||
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
|
||||
auto split_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1235,7 +1334,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false, IsFP8>;
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
if (is_scale_channel_wise) {
|
||||
split_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
@@ -1253,7 +1354,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true, IsFP8>;
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(split_kv_kernel,
|
||||
@@ -1295,7 +1398,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false, IsFP8>;
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
if (is_scale_channel_wise) {
|
||||
nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
@@ -1313,7 +1418,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true, IsFP8>;
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(nosplit_kv_kernel,
|
||||
@@ -1350,6 +1457,7 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1424,6 +1532,7 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1546,6 +1655,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
@@ -1554,6 +1664,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim = meta_data.head_dims;
|
||||
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";
|
||||
|
||||
DISPATCH_CAUSAL(
|
||||
causal,
|
||||
@@ -1572,43 +1683,46 @@ void CascadeAppendAttentionC8Kernel(
|
||||
BLOCK_SIZE,
|
||||
{DISPATCH_BLOCKSHAPE_Q(
|
||||
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
|
||||
MultiQueryAppendC8Attention<T,
|
||||
GROUP_SIZE,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
CAUSAL,
|
||||
BLOCK_SHAPE_Q,
|
||||
NUM_WARP_Q,
|
||||
OutT,
|
||||
ENABLE_PREFILL, IsFP8>(
|
||||
meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_mask,
|
||||
cache_k_scale.get(),
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
})})})})})})
|
||||
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
|
||||
MultiQueryAppendC8Attention<T,
|
||||
GROUP_SIZE,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
CAUSAL,
|
||||
BLOCK_SHAPE_Q,
|
||||
NUM_WARP_Q,
|
||||
OutT,
|
||||
ENABLE_PREFILL,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_mask,
|
||||
cache_k_scale.get(),
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
})})})})})})})
|
||||
}
|
||||
|
@@ -384,6 +384,113 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_k_dynamic_scale(
|
||||
T* k_smem_scale,
|
||||
T* cache_k_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_k_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
if (tid < block_size) {
|
||||
k_smem_scale[tid] = cache_k_scale_now[tid];
|
||||
}
|
||||
__syncthreads();
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
||||
if (kv_idx_this_thread < chunk_end) {
|
||||
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
|
||||
} else {
|
||||
k_smem_scale[ty * 32 + tx] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_v_dynamic_scale(
|
||||
T* v_smem_scale,
|
||||
T* cache_v_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_v_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
if (tid < block_size) {
|
||||
v_smem_scale[tid] = cache_v_scale_now[tid];
|
||||
}
|
||||
__syncthreads();
|
||||
const uint32_t row_id = tx % 4 * 2;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
||||
if (kv_idx_this_thread < chunk_end) {
|
||||
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
|
||||
} else {
|
||||
v_smem_scale[ty * 32 + tx] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
const uint32_t row_id = tx % 4 * 2;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <SharedMemFillMode fill_mode,
|
||||
uint32_t num_warps,
|
||||
uint32_t block_size,
|
||||
@@ -816,7 +923,8 @@ template <uint32_t num_frags_x,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8=false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||
uint32_t* q_smem_offset_r,
|
||||
smem_t* k_smem,
|
||||
@@ -860,20 +968,27 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
const int scale_col = (ky * 2 + fy) * 4;
|
||||
b_frag_dq_T[0] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
const int scale_col = (ky * 2 + fy) * 4;
|
||||
b_frag_dq_T[0] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[0];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[0];
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
@@ -929,7 +1044,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
8 * (reg_id / 4) + reg_id % 2;
|
||||
bool out_of_boundary;
|
||||
if (mask_offset) {
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
|
||||
} else {
|
||||
out_of_boundary =
|
||||
(causal
|
||||
@@ -1093,7 +1208,9 @@ template <uint32_t num_frags_x,
|
||||
uint32_t block_size,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__device__ __forceinline__ void compute_sfm_v_c8(
|
||||
smem_t* v_smem,
|
||||
uint32_t* v_smem_offset_r,
|
||||
@@ -1135,16 +1252,28 @@ __device__ __forceinline__ void compute_sfm_v_c8(
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
const int scale_col = (kz * 2 + fz) * 4;
|
||||
b_frag_dq_T[0] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
||||
@@ -1171,7 +1300,9 @@ template <uint32_t num_frags_x,
|
||||
uint32_t block_size,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
smem_t* v_smem,
|
||||
uint32_t* v_smem_offset_r,
|
||||
@@ -1215,16 +1346,28 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
const int scale_col = (kz * 2 + fz) * 4;
|
||||
b_frag_dq_T[0] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
||||
|
@@ -103,6 +103,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -264,9 +265,10 @@ void CascadeAppendAttentionKernel(
|
||||
causal,
|
||||
is_decoder,
|
||||
enable_prefill,
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
out);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
||||
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
@@ -299,6 +301,7 @@ void CascadeAppendAttentionKernel(
|
||||
causal,
|
||||
is_decoder,
|
||||
enable_prefill,
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
out);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
|
@@ -18,6 +18,53 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
|
||||
// Note(ZKK)
|
||||
// This function is very easy!
|
||||
// just make HeadDim data to be new HeadDim data!
|
||||
|
||||
template <typename T, int VecSize=8, int HEAD_DIM=128, int NUM_THREADS=32>
|
||||
__device__ __forceinline__ void apply_rope(
|
||||
const T* input,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
T* output,
|
||||
const int thread_id) {
|
||||
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM; head_bias += NUM_THREADS * VecSize) {
|
||||
Load<T, VecSize>(&input[head_bias], &src_vec);
|
||||
const uint32_t emb_idx = head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &output[head_bias]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -28,7 +75,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -120,7 +167,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
|
||||
if (hi < num_heads) { // q
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
@@ -129,6 +175,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
}
|
||||
} else { // k
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
@@ -164,7 +211,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -270,7 +317,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -381,142 +428,6 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadKVT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
|
||||
LoadT left_vec, right_vec;
|
||||
LoadBiasT left_bias_vec, right_bias_vec;
|
||||
LoadKVT left_cache_vec, right_cache_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int half_head_size = head_size / 2;
|
||||
const int half_rotary_dim = rotary_dim / 2;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||
const int64_t half_hidden_size = hidden_size / 2;
|
||||
// const int64_t offset = 2 * hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / half_hidden_size;
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
if (hi < num_heads && h_bias >= half_rotary_dim){
|
||||
continue;
|
||||
}
|
||||
if (seq_lens_encoder[ori_bi] > 0) continue;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
uint32_t ori_idx_left =
|
||||
start_token_idx * hidden_size + hi * head_size + h_bias;
|
||||
uint32_t ori_idx_right = ori_idx_left + half_head_size;
|
||||
if (hi < num_heads){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else if (hi < num_heads + kv_num_heads){
|
||||
if (h_bias < half_rotary_dim){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
ori_idx_left = ori_idx_left + half_rotary_dim;
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}
|
||||
}
|
||||
|
||||
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);
|
||||
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
if (h_bias < half_rotary_dim){
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
// rope
|
||||
float input_left = static_cast<float>(left_vec[i]);
|
||||
float input_right = static_cast<float>(right_vec[i]);
|
||||
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
|
||||
} else {
|
||||
// write k/v
|
||||
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
|
||||
uint32_t tgt_idx_left =
|
||||
block_idx * kv_num_heads * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size + block_offset * head_size +
|
||||
h_bias;
|
||||
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
if (h_bias < half_rotary_dim) {
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
tgt_idx_left = tgt_idx_left + half_rotary_dim;
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
|
||||
} else {
|
||||
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -527,7 +438,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -641,7 +551,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -765,6 +674,293 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
|
||||
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
T* __restrict__ cache_k_scale,
|
||||
T* __restrict__ cache_v_scale,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d,
|
||||
const float rms_norm_eps) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane_id = tid % 32;
|
||||
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + bid * max_blocks_per_seq;
|
||||
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
int cache_offset;
|
||||
if (head_idx < num_heads) {
|
||||
cache_offset = 0;
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
|
||||
}
|
||||
T *cache_k_scale_now = cache_k_scale + cache_offset;
|
||||
T *cache_v_scale_now = cache_v_scale + cache_offset;
|
||||
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(tmp2);
|
||||
}
|
||||
// qk norm
|
||||
if (q_norm_weight) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadOutScaleT q_norm_vec;
|
||||
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
|
||||
if (block_offset == 0) {
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||
const uint32_t tgt_idx =
|
||||
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
|
||||
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||
for (int block_i = lane_id / num_vecs_per_head_dim;
|
||||
block_i < block_size;
|
||||
block_i += num_token_each_time) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&key_cache[tgt_idx + block_i * HeadDim]);
|
||||
}
|
||||
} else {
|
||||
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
|
||||
const int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||
const uint32_t tgt_idx =
|
||||
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
|
||||
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
|
||||
block_i += num_token_each_time) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
constexpr int K_VEC_SIZE = 4;
|
||||
constexpr int HALF_K_VEC_SIZE = 2;
|
||||
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
|
||||
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
|
||||
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
||||
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
||||
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
|
||||
using LoadEmbT = AlignedVector<float, 1>;
|
||||
LoadKVResT cache_vec;
|
||||
LoadT src_vec1, src_vec2;
|
||||
LoadBiasT out_vec1, out_vec2;
|
||||
LoadEmbT cos_emb_vec1, cos_emb_vec2;
|
||||
LoadEmbT sin_emb_vec1, sin_emb_vec2;
|
||||
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
||||
T scale = T(1.0f);
|
||||
const int k_head_idx = head_idx - num_heads;
|
||||
const int v_head_idx = head_idx - num_heads - kv_num_heads;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
}
|
||||
|
||||
float input_left = static_cast<float>(src_vec1[0]);
|
||||
float input_right = static_cast<float>(src_vec1[1]);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec1[0] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec1[1] =
|
||||
static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec1[0] = src_vec1[0];
|
||||
out_vec1[1] = src_vec1[1];
|
||||
}
|
||||
|
||||
// rope
|
||||
input_left = static_cast<float>(src_vec2[0]);
|
||||
input_right = static_cast<float>(src_vec2[1]);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec2[0] = static_cast<T>(tmp1);
|
||||
out_vec2[1] = static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec2[0] = src_vec2[0];
|
||||
out_vec2[1] = src_vec2[1];
|
||||
}
|
||||
if (k_norm_weight) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
LoadOutScaleT k_norm_vec1, k_norm_vec2;
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
|
||||
// qk norm
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
|
||||
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// reduce max, 1 head per warp
|
||||
T local_max = -INFINITY;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
local_max = __hmax(local_max, __habs(out_vec1[i]));
|
||||
local_max = __hmax(local_max, __habs(out_vec2[i]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int m_offset = 16; m_offset > 1; m_offset /= 2) {
|
||||
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
||||
}
|
||||
|
||||
scale = __hdiv(448, local_max);
|
||||
|
||||
if (lane_id == 0) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
cache_k_scale_now[0] = __hdiv(1, scale);
|
||||
} else {
|
||||
cache_v_scale_now[0] = __hdiv(1, scale);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
|
||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
|
||||
}
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const int start_block_16 =
|
||||
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
|
||||
const uint32_t tgt_cache_idx =
|
||||
block_idx * kv_num_heads * block_size * HeadDim +
|
||||
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
|
||||
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
|
||||
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
|
||||
} else {
|
||||
const uint32_t base_tgt_cache_idx =
|
||||
block_idx * kv_num_heads * HeadDim * block_size +
|
||||
kv_head_idx * HeadDim * block_size +
|
||||
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
|
||||
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
|
||||
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
|
||||
block_offset % 8 / 2 * 4 // per 4
|
||||
+ block_offset % 16 / 8 * 2 // per 2
|
||||
+ block_offset % 2; // per 1
|
||||
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
|
||||
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
|
||||
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
|
||||
value_cache[tgt_cache_idx1] = cache_vec[0];
|
||||
value_cache[tgt_cache_idx2] = cache_vec[1];
|
||||
value_cache[tgt_cache_idx3] = cache_vec[2];
|
||||
value_cache[tgt_cache_idx4] = cache_vec[3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
|
||||
__global__ void append_decode_cache_int8_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -775,7 +971,6 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -813,44 +1008,18 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(
|
||||
qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
@@ -1025,7 +1194,6 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1330,7 +1498,6 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1632,7 +1799,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2029,7 +2196,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2070,44 +2237,18 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(
|
||||
qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
@@ -2327,7 +2468,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2658,7 +2799,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -3031,7 +3172,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
|
@@ -21,7 +21,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -59,7 +58,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -84,7 +82,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -97,7 +94,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const cudaStream_t& stream,
|
||||
@@ -121,7 +117,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -138,8 +133,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (rotary_dim < dim_head){
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -154,33 +148,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}else{
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -191,7 +162,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -214,7 +184,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -238,7 +207,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -271,7 +239,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -297,7 +264,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -323,7 +289,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -349,7 +314,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -375,7 +339,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -410,7 +373,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -438,7 +400,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -466,7 +427,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -494,7 +454,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -521,7 +480,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -558,20 +516,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const float* cos_emb =
|
||||
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
|
||||
const float* sin_emb;
|
||||
int rotary_dim = dim_head;
|
||||
if (rotary_embs) {
|
||||
sin_emb =
|
||||
use_neox_rotary_style
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < dim_head){
|
||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
|
||||
}
|
||||
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
@@ -582,7 +531,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -605,9 +553,39 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
|
||||
}
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
@@ -617,7 +595,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -632,7 +609,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
@@ -650,7 +626,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -683,7 +658,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -717,7 +691,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -743,6 +716,36 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
nullptr,
|
||||
nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
@@ -750,7 +753,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -798,7 +800,6 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -828,7 +829,6 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -857,7 +857,6 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -886,7 +885,6 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
|
@@ -23,7 +23,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
|
@@ -449,8 +449,8 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
const int half_lastdim = last_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * last_dim;
|
||||
const int all_head_num = elem_cnt / last_dim;
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
|
||||
int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_num; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize;
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
@@ -900,74 +900,6 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const float *qkv_out_scales,
|
||||
const T *qkv_biases,
|
||||
T *qkv_out,
|
||||
const int64_t elem_cnt,
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int head_dim,
|
||||
const int rotary_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT left_vec;
|
||||
LoadT right_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int rotary_dim_half = rotary_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
|
||||
for (int64_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens && seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / rotary_dim_half;
|
||||
const int h_bias = bias % rotary_dim_half;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
|
||||
const int base_idx_left =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
|
||||
h_bias;
|
||||
const int base_idx_right = base_idx_left + rotary_dim_half;
|
||||
|
||||
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
const float input_left = static_cast<float>(left_vec[i]);
|
||||
const float input_right = static_cast<float>(right_vec[i]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void cache_kernel(
|
||||
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
|
||||
@@ -1300,6 +1232,411 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
uint32_t num_frags_y,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t HEAD_DIM,
|
||||
uint32_t BLOCK_SIZE,
|
||||
uint32_t NUM_WARPS,
|
||||
bool is_need_kv_quant,
|
||||
bool IsFP8 = true>
|
||||
__global__ void append_write_cache_kv_c8_qkv_dynamic(
|
||||
uint8_t *__restrict__ cache_k,
|
||||
uint8_t *__restrict__ cache_v,
|
||||
const T *__restrict__ qkv_input,
|
||||
T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size]
|
||||
T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size]
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids,
|
||||
const int *__restrict__ seq_lens_this_time,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
const int *__restrict__ batch_id_per_token,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_tables,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t pad_len = BLOCK_SIZE;
|
||||
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const T cache_k_scale = cache_k_scales[kv_head_idx];
|
||||
const T cache_v_scale = cache_v_scales[kv_head_idx];
|
||||
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
|
||||
const uint32_t batch_id = batch_ids[btid];
|
||||
const uint32_t tile_id = tile_ids[btid];
|
||||
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
|
||||
if (seq_len_this_time <= 0) {
|
||||
return;
|
||||
}
|
||||
const int *block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + batch_id * max_blocks_per_seq;
|
||||
|
||||
const uint32_t num_rows_per_block =
|
||||
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
|
||||
const uint32_t start_len = seq_lens_decoder[batch_id];
|
||||
const uint32_t bf_pad_len = start_len % pad_len;
|
||||
const uint32_t start_len_pad = start_len - bf_pad_len;
|
||||
const uint32_t end_len = start_len + seq_len_this_time;
|
||||
|
||||
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
||||
|
||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = HEAD_DIM;
|
||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_scale_smem[BLOCK_SIZE];
|
||||
if (tile_start >= start_len) {
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
// reset k
|
||||
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
|
||||
uint32_t tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
|
||||
tid % num_vecs_per_head_k * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_k;
|
||||
block_i < BLOCK_SIZE;
|
||||
block_i += num_token_each_time_k) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&cache_k[tgt_idx + block_i * HEAD_DIM]);
|
||||
}
|
||||
|
||||
// reset v
|
||||
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
|
||||
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
|
||||
tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
|
||||
tid % num_vecs_per_head_v * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
|
||||
block_i += num_token_each_time_v) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
|
||||
}
|
||||
}
|
||||
smem_t k_smem(k_smem_ori);
|
||||
smem_t v_smem(v_smem_ori);
|
||||
|
||||
uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp
|
||||
|
||||
/*
|
||||
0 | 1
|
||||
2 | 3
|
||||
*/
|
||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
||||
|
||||
constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
|
||||
/*
|
||||
0 | 2
|
||||
1 | 3
|
||||
*/
|
||||
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
tid % 16, wid * num_frags_v * 2 + tid / 16);
|
||||
|
||||
// load kv gmem to smem
|
||||
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
|
||||
tile_id * num_rows_per_block +
|
||||
wid * num_frags_z * 16 + tid / 8;
|
||||
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
|
||||
(num_heads + kv_head_idx) * kv_h_stride +
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
|
||||
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; ++j) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y / 4;
|
||||
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
|
||||
if (chunk_start >= start_len && chunk_start < end_len) {
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
|
||||
}
|
||||
kv_smem_offset_w =
|
||||
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
|
||||
k_read_idx += 8 * num_elems_per_128b<T>();
|
||||
v_read_idx += 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
kv_smem_offset_w =
|
||||
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
|
||||
2 * num_frags_y;
|
||||
chunk_start += 4;
|
||||
k_read_idx +=
|
||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
||||
v_read_idx +=
|
||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
commit_group();
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// reduce scale
|
||||
// 16 rows per warp
|
||||
uint32_t kv_reduce_frag[4];
|
||||
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);
|
||||
|
||||
T k_local_max_value[num_frags_z * 2];
|
||||
T v_local_max_value[num_frags_z * 2];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
||||
k_local_max_value[i] = -INFINITY;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
||||
v_local_max_value[i] = -INFINITY;
|
||||
}
|
||||
const int num_kv_heads = gridDim.z;
|
||||
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
|
||||
T *cache_k_scale_now = cache_k_scales + scale_offset;
|
||||
T *cache_v_scale_now = cache_v_scales + scale_offset;
|
||||
// k scale
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
// reduce per thread, 4 threads each row
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
|
||||
}
|
||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
// reduce per row
|
||||
for (int i = 0; i < 2; i++) {
|
||||
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
||||
// used for quant
|
||||
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
||||
}
|
||||
// store
|
||||
if (tid % 4 == 0) {
|
||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
||||
// used for dequant
|
||||
if (tile_start + offset_now >= start_len) {
|
||||
if (tile_start + offset_now < end_len) {
|
||||
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
|
||||
} else {
|
||||
cache_k_scale_now[offset_now] = 0;
|
||||
}
|
||||
}
|
||||
if (tile_start + offset_now + 8 >= start_len) {
|
||||
if (tile_start + offset_now + 8 < end_len) {
|
||||
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
|
||||
} else {
|
||||
cache_k_scale_now[offset_now + 8] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
||||
}
|
||||
// v scale
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
// reduce per thread, 4 threads each row
|
||||
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
|
||||
}
|
||||
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
// reduce per row
|
||||
for (int i = 0; i < 2; i++) {
|
||||
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
||||
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
||||
}
|
||||
// store
|
||||
if (tid % 4 == 0) {
|
||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
||||
// used for dequant
|
||||
if (tile_start + offset_now >= start_len) {
|
||||
if (tile_start + offset_now < end_len) {
|
||||
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
|
||||
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
|
||||
} else {
|
||||
cache_v_scale_now[offset_now] = 0;
|
||||
v_scale_smem[offset_now] = 0;
|
||||
}
|
||||
}
|
||||
if (tile_start + offset_now + 8 >= start_len) {
|
||||
if (tile_start + offset_now + 8 < end_len) {
|
||||
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
|
||||
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
|
||||
} else {
|
||||
cache_v_scale_now[offset_now + 8] = 0;
|
||||
v_scale_smem[offset_now + 8] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// mask, quant, store
|
||||
using LoadKVT = AlignedVector<uint8_t, 4>;
|
||||
LoadKVT cache_vec1;
|
||||
LoadKVT cache_vec2;
|
||||
|
||||
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
|
||||
uint32_t kv_frag[4];
|
||||
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_b_stride = HEAD_DIM;
|
||||
const uint32_t write_d_stride = BLOCK_SIZE;
|
||||
uint32_t k_write_idx = block_id * write_n_stride +
|
||||
kv_head_idx * write_h_stride +
|
||||
(wid * num_frags_z * 16 + tid / 4) * write_b_stride +
|
||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t k_write_idx_now = k_write_idx_now_z +
|
||||
fy % 2 * 8 * write_b_stride +
|
||||
fy / 2 * 32; // + fy % 2 * 16;
|
||||
// load
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||
// quant
|
||||
T *k_frag_T = reinterpret_cast<T *>(kv_frag);
|
||||
if (bf_pad_len != 0) {
|
||||
Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
|
||||
Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
||||
uint8_t uint_quant_value;
|
||||
if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
|
||||
chunk_start_k + (v_id / 4) * 8 < end_len) {
|
||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
|
||||
} else {
|
||||
uint_quant_value = 0;
|
||||
}
|
||||
if (bf_pad_len != 0) {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] |= uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
||||
}
|
||||
} else {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] = uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id - 4] = uint_quant_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
|
||||
Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
|
||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
k_smem_offset_r =
|
||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
|
||||
2 * num_frags_y;
|
||||
chunk_start_k += 16;
|
||||
}
|
||||
|
||||
uint32_t chunk_start_v = tile_start + tid % 4 * 2;
|
||||
uint32_t v_write_idx = block_id * write_n_stride +
|
||||
kv_head_idx * write_h_stride +
|
||||
(wid * num_frags_v * 16 + tid / 4) * write_d_stride +
|
||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
||||
const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
|
||||
T v_scales[num_frags_z_v * 4];
|
||||
for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
|
||||
const int offset = v_i * 16;
|
||||
const int t_offset = tid % 4 * 2;
|
||||
v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
|
||||
v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
|
||||
v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
|
||||
v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
|
||||
uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
|
||||
uint32_t v_write_idx_now = v_write_idx_now_v +
|
||||
fz % 2 * 8 * write_d_stride +
|
||||
fz / 2 * 32; // + fz % 2 * 16;
|
||||
// load
|
||||
v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
|
||||
// quant
|
||||
T *v_frag_T = reinterpret_cast<T *>(kv_frag);
|
||||
if (bf_pad_len != 0) {
|
||||
Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
|
||||
Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
||||
uint8_t uint_quant_value;
|
||||
if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
|
||||
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
|
||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
|
||||
// store now
|
||||
} else {
|
||||
uint_quant_value = 0;
|
||||
}
|
||||
if (bf_pad_len != 0) {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] |= uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
||||
}
|
||||
} else {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] = uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] = uint_quant_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
|
||||
Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
|
||||
chunk_start_v += 16;
|
||||
v_smem_offset_r =
|
||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
|
||||
}
|
||||
v_smem_offset_r = k_smem.advance_offset_by_column<2>(
|
||||
v_smem_offset_r, wid * num_frags_v + fy) -
|
||||
16 * num_frags_z_v * num_vecs_per_head;
|
||||
chunk_start_v -= 16 * num_frags_z_v;
|
||||
}
|
||||
}
|
||||
|
||||
// Write Cache KV in Append
|
||||
template <typename T,
|
||||
uint32_t num_frags_y,
|
||||
@@ -1823,7 +2160,6 @@ void gqa_rotary_qk_variable(
|
||||
const int seq_len,
|
||||
const int input_output_len,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const cudaStream_t &stream,
|
||||
bool use_neox_style = false,
|
||||
bool rope_3d = false) {
|
||||
@@ -1904,38 +2240,7 @@ void gqa_rotary_qk_variable(
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (rotary_dim < dim_head){
|
||||
PD_CHECK((rotary_dim / 2) % PackSize == 0);
|
||||
elem_nums =
|
||||
qkv_out_scales
|
||||
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
|
||||
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
|
||||
if (use_neox_style) {
|
||||
elem_nums /= 2;
|
||||
}
|
||||
const int pack_num_new = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num_new, &grid_size);
|
||||
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
rotary_emb + input_output_len * rotary_dim / 2,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
rope_3d);
|
||||
}else{
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
@@ -1953,7 +2258,6 @@ void gqa_rotary_qk_variable(
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2107,10 +2411,11 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
int num_blocks_x_cpu,
|
||||
int max_seq_len,
|
||||
bool is_scale_channel_wise,
|
||||
const bool is_fp8,
|
||||
const std::string& cache_quant_type,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *cache_k_out,
|
||||
paddle::Tensor *cache_v_out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto num_tokens = meta_data.token_nums;
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
@@ -2128,49 +2433,77 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
dim3 blocks(32, num_warps);
|
||||
|
||||
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, false>;
|
||||
if (is_fp8) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
if (cache_quant_type != "block_wise_fp8") {
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, false>;
|
||||
if (cache_quant_type == "cache_fp8") {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
false>;
|
||||
}
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
qkv.data<T>(),
|
||||
cache_k_scale.data<T>(),
|
||||
cache_v_scale.data<T>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
} else {
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
|
||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
|
||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
false>;
|
||||
}
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
qkv.data<T>(),
|
||||
cache_k_scale.data<T>(),
|
||||
cache_v_scale.data<T>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
}
|
||||
|
||||
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
|
||||
|
@@ -55,19 +55,9 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto head_dim = meta_data.head_dims;
|
||||
bool is_scale_channel_wise = false;
|
||||
int rotary_dim = head_dim;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (rotary_embs){
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < head_dim){
|
||||
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
|
||||
@@ -135,7 +125,6 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
@@ -178,7 +167,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
|
||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
||||
DISPATCH_HEAD_DIM(
|
||||
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
|
||||
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
|
||||
@@ -198,7 +187,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
is_scale_channel_wise,
|
||||
cache_quant_type_str == "cache_fp8",
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
|
@@ -191,14 +191,21 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
void GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
|
||||
paddle::Tensor &kv_batch_ids, // Inplace
|
||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
|
||||
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
@@ -223,13 +230,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
int max_system_len = max_len_cpu_ptr[6];
|
||||
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
|
||||
|
||||
paddle::Tensor encoder_batch_ids;
|
||||
paddle::Tensor encoder_tile_ids_per_batch;
|
||||
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor kv_batch_ids;
|
||||
paddle::Tensor kv_tile_ids_per_batch;
|
||||
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||
|
||||
|
||||
auto max_len_kv =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
|
||||
@@ -237,17 +238,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(), bsz);
|
||||
|
||||
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
|
||||
|
||||
if (max_enc_len_this_time > 0) {
|
||||
const uint32_t max_tile_size_per_bs_kv =
|
||||
div_up(max_enc_dec_len_this_time, block_size);
|
||||
kv_batch_ids =
|
||||
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
|
||||
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
|
||||
auto kv_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
|
||||
@@ -258,16 +256,12 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
|
||||
block_size, block_size);
|
||||
|
||||
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
const uint32_t encoder_max_tile_size_per_bs_q =
|
||||
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||
encoder_batch_ids =
|
||||
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
|
||||
// Clear buffer
|
||||
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
|
||||
auto encoder_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
|
||||
@@ -275,21 +269,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
encoder_tile_ids_per_batch.data<int>(),
|
||||
encoder_num_blocks_x.data<int>(), bsz,
|
||||
encoder_block_shape_q, group_size);
|
||||
encoder_num_blocks_x_cpu =
|
||||
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
} else {
|
||||
encoder_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
kv_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
|
||||
}
|
||||
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
@@ -314,15 +294,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
|
||||
}
|
||||
|
||||
return {
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks_x_cpu, /*cpu*/
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks_x_cpu, /*cpu*/
|
||||
max_len_kv_cpu, /*cpu*/
|
||||
};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
@@ -333,16 +304,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks_x_cpu",
|
||||
"max_len_tensor_cpu"
|
||||
"max_len_tensor_cpu",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks_x_cpu",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks_x_cpu",
|
||||
"max_len_kv_cpu"
|
||||
})
|
||||
.Outputs({
|
||||
paddle::Optional("encoder_batch_ids"),
|
||||
paddle::Optional("encoder_tile_ids_per_batch"),
|
||||
paddle::Optional("encoder_num_blocks_x_cpu"),
|
||||
paddle::Optional("kv_batch_ids"),
|
||||
paddle::Optional("kv_tile_ids_per_batch"),
|
||||
paddle::Optional("kv_num_blocks_x_cpu"),
|
||||
"max_len_kv_cpu"
|
||||
|
||||
})
|
||||
.Attrs({
|
||||
"encoder_block_shape_q: int",
|
||||
|
@@ -217,7 +217,7 @@ __global__ void append_cache_kv_c16(
|
||||
|
||||
// load k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -235,7 +235,7 @@ __global__ void append_cache_kv_c16(
|
||||
// deal k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||
// layout
|
||||
@@ -278,7 +278,7 @@ __global__ void append_cache_kv_c16(
|
||||
|
||||
// load v_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -296,7 +296,7 @@ __global__ void append_cache_kv_c16(
|
||||
// deal v_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
|
||||
// layout
|
||||
@@ -400,7 +400,7 @@ __global__ void append_cache_kv_c8(
|
||||
|
||||
// load v_smem 64 rows, 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -418,7 +418,7 @@ __global__ void append_cache_kv_c8(
|
||||
// deal k_smem 64 rows, 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter
|
||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter
|
||||
uint32_t col_idx = fy * 32 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||
// layout
|
||||
@@ -466,7 +466,7 @@ __global__ void append_cache_kv_c8(
|
||||
tid % 4 * num_elems_per_128b<CacheT>();
|
||||
// load v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 noce, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -485,7 +485,7 @@ __global__ void append_cache_kv_c8(
|
||||
// deal v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 noce, need 2 iter
|
||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter
|
||||
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||
// layout
|
||||
@@ -614,7 +614,7 @@ __global__ void append_cache_kv_c4(
|
||||
|
||||
// load k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 noce, need 1 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -632,7 +632,7 @@ __global__ void append_cache_kv_c4(
|
||||
// deal k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 noce, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
|
||||
uint32_t col_idx = fy * 64 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||
|
||||
@@ -685,7 +685,7 @@ __global__ void append_cache_kv_c4(
|
||||
tid % 2 * num_elems_per_128b<CacheT>();
|
||||
// load v_smem 128 rows 64 rows
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -704,7 +704,7 @@ __global__ void append_cache_kv_c4(
|
||||
// deal v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
||||
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||
// layout
|
||||
@@ -1000,7 +1000,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
|
||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
|
||||
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
|
||||
meta_data,
|
||||
*const_cast<paddle::Tensor*>(&key_cache),
|
||||
@@ -1018,7 +1018,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
kv_num_blocks_data,
|
||||
max_seq_len,
|
||||
false, // is_scale_channel_wise
|
||||
cache_quant_type == "cache_fp8", // is_fp8
|
||||
cache_quant_type,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
|
@@ -18,6 +18,166 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ q_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
const float*
|
||||
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
|
||||
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int output_inner_dim,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
LoadInT src_vec;
|
||||
LoadFloat scale_vec;
|
||||
LoadT bias_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
LoadFloat tmp_vec;
|
||||
LoadFloat q_norm_vec;
|
||||
LoadFloat k_norm_vec;
|
||||
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
int64_t all_head_dim = elem_cnt / head_size;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
|
||||
const int half_head_size = head_size / 2;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens_decoder[ori_bi],
|
||||
token_id,
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
const int write_q_idx =
|
||||
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
|
||||
|
||||
const int bias_idx = hi * head_size + h_bias;
|
||||
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
|
||||
if (qkv_biases) {
|
||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||
}
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
}
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
if (qkv_out_scales) {
|
||||
input_left *= scale_vec[2 * i];
|
||||
input_right *= scale_vec[2 * i + 1];
|
||||
}
|
||||
if (qkv_biases) {
|
||||
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
|
||||
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < (num_heads + gqa_group_size)) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
if (hi < num_heads) {
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else {
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
|
||||
} else {
|
||||
// write k/v
|
||||
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
|
||||
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size +
|
||||
block_offset * head_size + h_bias);
|
||||
// write
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
|
||||
} else {
|
||||
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int VecSize = 4, int HeadDim = 128>
|
||||
__global__ void append_clear_cache_int8_block(
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
||||
@@ -193,7 +353,8 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -253,8 +414,9 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -326,7 +488,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -390,8 +553,9 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -476,7 +640,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -522,8 +687,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
}
|
||||
@@ -583,10 +749,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
T scale;
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
} else {
|
||||
scale = __ldg(&cache_v_scales[kv_head_idx]);
|
||||
@@ -708,7 +875,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -757,8 +925,9 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
|
||||
&left_out_scale_vec);
|
||||
@@ -853,10 +1022,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
@@ -1088,7 +1258,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1145,8 +1316,9 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -1235,10 +1407,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// &out_scale_vec2);
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||
@@ -1431,7 +1604,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1581,10 +1755,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec2);
|
||||
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
|
||||
&left_scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
|
||||
|
@@ -15,6 +15,77 @@
|
||||
#include "speculate_write_cache_with_rope_kernel.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
constexpr int HEAD_DIM = 128;
|
||||
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
} else {
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
output_inner_dim,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
}
|
||||
|
||||
// rope + write
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
@@ -39,7 +110,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
|
||||
const uint32_t elem_nums =
|
||||
@@ -73,7 +145,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_rope_kernel<T, PackSize>
|
||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||
@@ -96,7 +169,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +199,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -167,7 +242,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -191,7 +267,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,7 +299,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -266,7 +344,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_int4_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -292,7 +371,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
template <typename T, typename QKV_TYPE>
|
||||
@@ -313,11 +393,15 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out) {
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
typedef cascade_attn_type_traits<T> traits_;
|
||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||
typedef typename traits_::type DataType_;
|
||||
@@ -342,142 +426,184 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
}
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope_qk_norm(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
|
||||
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -500,11 +626,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
template void
|
||||
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
@@ -526,11 +656,15 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -551,11 +685,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
|
||||
template void
|
||||
@@ -578,8 +716,12 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
@@ -35,8 +35,12 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
@@ -56,6 +56,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -103,5 +104,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -98,5 +99,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -441,6 +441,15 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
PD_THROW("not support the group_size", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
|
||||
if (is_dynamic_cfp8) { \
|
||||
constexpr bool IsDynamicC8 = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
constexpr bool IsDynamicC8 = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
|
@@ -255,7 +255,8 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums);
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size);
|
||||
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
@@ -298,7 +299,7 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
|
||||
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
|
||||
const int layer_id);
|
||||
|
||||
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
void GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
@@ -306,6 +307,13 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &kv_batch_ids, // Inplace
|
||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
@@ -378,9 +386,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size);
|
||||
|
||||
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
paddle::Tensor
|
||||
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
|
||||
@@ -707,6 +717,22 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
@@ -750,6 +776,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
@@ -763,7 +790,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill);
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
|
||||
|
||||
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
@@ -980,7 +1008,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"),
|
||||
py::arg("block_size"),
|
||||
"per token per block quant and padding tranpose scale");
|
||||
"per token per block quant and padding transpose scale");
|
||||
|
||||
m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"),
|
||||
py::arg("recv_expert_count"), py::arg("block_size"),
|
||||
@@ -1023,7 +1051,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_ffn_wint2.cu
|
||||
* moe/fused_moe/moe_expert_ffn_wint2.cu
|
||||
* moe_expert_ffn_wint2
|
||||
*/
|
||||
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
|
||||
@@ -1228,6 +1256,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
|
||||
|
||||
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
|
||||
|
||||
m.def("ngram_match", &NgramMatch, "ngram_match function");
|
||||
|
||||
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
|
||||
|
@@ -89,11 +89,11 @@ public:
|
||||
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of warp-level GEMM oeprations per load for B
|
||||
/// Number of warp-level GEMM operations per load for B
|
||||
static constexpr int kWarpGemmIterationsPerLoadForB =
|
||||
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
|
||||
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
|
||||
|
@@ -117,7 +117,7 @@ class LeftGELUAndMul {
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &lhs,
|
||||
FragmentAccumulator const &rhs) const {
|
||||
// Convert source to interal compute numeric type
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
||||
accumulator_to_compute;
|
||||
|
||||
|
@@ -117,7 +117,7 @@ class LeftSiLUAndMul {
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &lhs,
|
||||
FragmentAccumulator const &rhs) const {
|
||||
// Convert source to interal compute numeric type
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
||||
accumulator_to_compute;
|
||||
|
||||
|
@@ -92,7 +92,7 @@ class DualMmaBase {
|
||||
Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator0::Policy::MmaShape::kK);
|
||||
|
||||
|
@@ -219,7 +219,7 @@ class EpilogueVisitorPerRowPerColNf4 {
|
||||
iterator_C_.clear_mask();
|
||||
}
|
||||
// NOTE(wangbojun) Currently, this kernel don't hanve implantention for
|
||||
// adding elementwise beta, we keep this here for future useage beta_ =
|
||||
// adding elementwise beta, we keep this here for future usage beta_ =
|
||||
// (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr :
|
||||
// params.elementwise.beta); if (beta_ == ElementAccumulator()) {
|
||||
// iterator_C_.clear_mask();
|
||||
|
@@ -176,7 +176,7 @@ struct Nf4DefaultIteratorsTensorOp<cutlass::bfloat16_t,
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
|
||||
public:
|
||||
|
@@ -64,7 +64,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
|
@@ -133,7 +133,7 @@ public:
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
|
||||
static constexpr int kNumKIterationsPerWarpBLoad =
|
||||
|
@@ -509,7 +509,7 @@ public:
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
// TOOD(wangbojun) lds_converter can be remove for int8 B input
|
||||
// TODO(wangbojun) lds_converter can be remove for int8 B input
|
||||
typename TransformBAfterLDS::result_type converted_frag_B =
|
||||
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
|
||||
|
@@ -96,7 +96,7 @@ public:
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
|
||||
static constexpr int kNumKIterationsPerWarpBLoad =
|
||||
|
@@ -646,7 +646,7 @@ public:
|
||||
// );
|
||||
// }
|
||||
}
|
||||
// TOOD(wangbojun) lds_converter can be remove for int8 B input
|
||||
// TODO(wangbojun) lds_converter can be remove for int8 B input
|
||||
// int4
|
||||
// typename TransformBAfterLDS::result_type converted_frag_B =
|
||||
// lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
|
@@ -171,7 +171,7 @@ struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
|
||||
public:
|
||||
|
@@ -30,12 +30,10 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
|
||||
std::optional<paddle::Tensor> const& maybe_token_scales,
|
||||
std::string maybe_schedule) {
|
||||
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
|
||||
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
|
||||
std::optional<int64_t> maybe_group_size_opt;
|
||||
std::optional<std::string> maybe_schedule_opt;
|
||||
if (maybe_schedule == "") {
|
||||
maybe_schedule_opt = std::nullopt;
|
||||
} else {
|
||||
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
|
||||
}
|
||||
return machete::mm_dispatch({.A = A,
|
||||
.B = B,
|
||||
@@ -65,8 +63,6 @@ std::vector<paddle::Tensor> MacheteMMKernel(
|
||||
paddle::DataType maybe_out_type;
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
@@ -51,8 +51,6 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
|
||||
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
@@ -383,7 +383,7 @@ __global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attenti
|
||||
|
||||
|
||||
template<typename Kernel_traits, typename ParamType>
|
||||
inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType ¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
|
||||
inline __device__ float calculate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType ¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
|
||||
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
|
||||
constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads;
|
||||
const int32_t bi = blockIdx.z;
|
||||
@@ -524,7 +524,7 @@ __global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_a
|
||||
const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize;
|
||||
const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN;
|
||||
|
||||
float inv_global_exp_sum = caluate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
|
||||
float inv_global_exp_sum = calculate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
|
||||
|
||||
|
||||
using T_vec = Vec<cuteType, kNReducePacksize>;
|
||||
|
@@ -40,7 +40,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
if (seq_len == 0) return;
|
||||
|
||||
const int ramian_tokens = seq_len - block_idx;
|
||||
const int remain_tokens = seq_len - block_idx;
|
||||
|
||||
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
|
||||
const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize];
|
||||
@@ -51,7 +51,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(k_input + base_load_idx + i * kv_head_num * kHeadDim);
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(v_input + base_load_idx + i * kv_head_num * kHeadDim);
|
||||
}
|
||||
}
|
||||
|
@@ -50,14 +50,14 @@ __global__ void get_kv_from_cache_c16_kernel(
|
||||
const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx];
|
||||
|
||||
|
||||
const int ramian_tokens = seq_len - base_token_idx;
|
||||
const int remain_tokens = seq_len - base_token_idx;
|
||||
|
||||
if (bidh < kv_head_num) {
|
||||
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_k + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
@@ -67,7 +67,7 @@ __global__ void get_kv_from_cache_c16_kernel(
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_v + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
|
@@ -872,16 +872,14 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
OutT* out,
|
||||
cudaStream_t &stream) {
|
||||
bool FLAGS_hardamard_use_diagonal_block_matrix = true;
|
||||
|
||||
static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
|
||||
static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
|
||||
stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
|
||||
constexpr int kThreads = 128;
|
||||
if (FLAGS_hardamard_use_diagonal_block_matrix) {
|
||||
const int VecSize = hardamard_moe_block_size / kThreads; // 128 / 128 = 1
|
||||
const int VecSize = hadamard_block_size / kThreads;
|
||||
const int logN = int(ceil(std::log2(kThreads * VecSize)));
|
||||
constexpr int kNChunks = 1;
|
||||
DISPATCH_SP_VS(VecSize, VEC_SIZE, {
|
||||
@@ -991,6 +989,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::float16 *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1009,6 +1008,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1027,6 +1027,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::bfloat16 *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1045,6 +1046,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
|
@@ -32,5 +32,6 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
OutT* out,
|
||||
cudaStream_t &stream);
|
||||
|
@@ -80,7 +80,7 @@ void MoeDispatchKernel(
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill sucess ?)
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
@@ -35,7 +35,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const std::string& quant_method,
|
||||
paddle::Tensor ffn_out,
|
||||
bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
@@ -291,6 +292,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
|
||||
stream
|
||||
);
|
||||
@@ -340,6 +342,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
act_out_tensor.data<data_t>(),
|
||||
stream
|
||||
);
|
||||
@@ -403,7 +406,7 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums, const int hadamard_block_size) {
|
||||
|
||||
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
|
||||
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
||||
@@ -424,7 +427,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
quant_method,
|
||||
ffn_out,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums);
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
@@ -439,7 +443,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
quant_method,
|
||||
ffn_out,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums);
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeExpertFFN");
|
||||
@@ -458,7 +463,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
return {MoeExpertFFNFunc(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
up_gate_proj_weight,
|
||||
@@ -470,7 +476,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
expert_idx_per_token,
|
||||
quant_method,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums)};
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size)};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
@@ -485,7 +492,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
|
||||
const std::string& quant_method,
|
||||
const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
return {permute_input_shape};
|
||||
}
|
||||
|
||||
@@ -499,7 +507,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
|
||||
const std::string &quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums, const int hadamard_block_size) {
|
||||
if (quant_method == "w4a8" || quant_method == "w4afp8") {
|
||||
return {up_gate_proj_scale_dtype.get()};
|
||||
} else {
|
||||
@@ -555,6 +563,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
* Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
|
||||
* - used_in_ep_low_latency: Whether running in low latency mode
|
||||
* Affects activation function implementation
|
||||
* - estimate_total_token_nums: estimate total token numbers
|
||||
* - hadamard_block_size: hadamard block size for w4a8/w4afp8 quantization
|
||||
*
|
||||
* Note:
|
||||
* - w4a8 mode requires additional workspace memory allocation
|
||||
@@ -571,7 +581,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
|
||||
paddle::Optional("down_proj_in_scale"),
|
||||
paddle::Optional("expert_idx_per_token")})
|
||||
.Outputs({"output_tensor"})
|
||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
|
||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
|
||||
|
@@ -15,31 +15,72 @@
|
||||
#include "helper.h"
|
||||
|
||||
__global__ void recover_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
}
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void recover_spec_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
int64_t *draft_tokens,
|
||||
const int64_t *step_draft_tokens,
|
||||
const int *step_seq_lens_this_time,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size,
|
||||
const int draft_tokens_len,
|
||||
const int num_extra_tokens) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size;
|
||||
max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq);
|
||||
if (block_table_now[max_possible_block_idx] != -1) {
|
||||
// can be recovered for decoding
|
||||
int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len;
|
||||
const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len;
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx];
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) {
|
||||
draft_tokens_now[i] = step_draft_tokens_now[i];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -47,7 +88,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
@@ -56,17 +101,38 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
#endif
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
if (draft_tokens) {
|
||||
const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1];
|
||||
recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<int64_t *>(draft_tokens.get_ptr()->data<int64_t>()),
|
||||
step_draft_tokens.get_ptr()->data<int64_t>(),
|
||||
step_seq_lens_this_time.get_ptr()->data<int>(),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
draft_tokens_len,
|
||||
max_draft_tokens * 2 + 1);
|
||||
|
||||
} else {
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
@@ -76,8 +142,11 @@ PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"block_tables",
|
||||
"is_block_step"})
|
||||
.Attrs({"block_size: int"})
|
||||
"is_block_step",
|
||||
paddle::Optional("draft_tokens"),
|
||||
paddle::Optional("step_draft_tokens"),
|
||||
paddle::Optional("step_seq_lens_this_time")})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
|
@@ -75,7 +75,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -45,7 +45,7 @@ void save_kernel(const paddle::Tensor& x,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -34,7 +34,7 @@ __global__ void set_value_by_flag_and_id(const bool *stop_flags,
|
||||
const int64_t *input_ids_now = input_ids + tid * length_input_ids;
|
||||
const int seq_len_dec = seq_lens_decoder[tid];
|
||||
const int seq_len_enc = seq_lens_encoder[tid];
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
|
||||
if (step_idx[tid] >= 0) {
|
||||
if (seq_len_enc > 0) { // encoder, get last token accord to seq_lens_encoder
|
||||
pre_ids_all_now[step_idx[tid]] = input_ids_now[seq_len_enc - 1];
|
@@ -15,7 +15,48 @@
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
|
||||
|
||||
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
|
||||
do { \
|
||||
constexpr int BlockSize = BLOCK_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
|
||||
do { \
|
||||
if (truncate_first_token) { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
|
||||
do { \
|
||||
if (kvcache_scheduler_v1) { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
|
||||
do { \
|
||||
if (splitwise_prefill) { \
|
||||
constexpr bool SPLITWISE_PREFILL = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool SPLITWISE_PREFILL = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
|
||||
__global__ void process_splitwise_prefill(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill(
|
||||
stop_flags[tid] = false;
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
int position = seq_len_encoder;
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
@@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill(
|
||||
|
||||
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
|
||||
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
|
||||
__global__ void draft_model_preprocess_kernel(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel(
|
||||
base_model_draft_tokens_now[i] = -1;
|
||||
}
|
||||
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
// 1. process block_step situation
|
||||
// -- In v0 mode, block_step will drop mtp query.
|
||||
// -- In v1 mode, block_step will continue to infer.
|
||||
if constexpr(KVCACHE_SCHEDULER_V1) {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
stop_flags[tid] = true;
|
||||
is_block_step[tid] = true;
|
||||
// Need to continue infer
|
||||
}
|
||||
} else {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. process normal query, not in any special case.
|
||||
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
|
||||
not_stop_flag = 1;
|
||||
// 1. first token
|
||||
// prefill generation
|
||||
if (seq_lens_encoder[tid] > 0) {
|
||||
// Can be extended to first few tokens
|
||||
int seq_len_encoder = seq_lens_encoder[tid];
|
||||
@@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
pre_ids_now[0] = base_model_first_token;
|
||||
int position = seq_len_encoder;
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
input_ids_now[position] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
||||
}
|
||||
} else {
|
||||
} else { // decode generation
|
||||
if constexpr (KVCACHE_SCHEDULER_V1) {
|
||||
// 3. try to recover mtp infer in V1 mode
|
||||
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
|
||||
is_block_step[tid] = false;
|
||||
}
|
||||
}
|
||||
if (stop_flags[tid]) {
|
||||
stop_flags[tid] = false;
|
||||
// TODO: check
|
||||
@@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool TRCUNCATE_FIRST_TOKEN>
|
||||
void DispatchRunner(
|
||||
const cudaStream_t& stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool splitwise_prefill) {
|
||||
constexpr int BlockSize = 512;
|
||||
if (splitwise_prefill) {
|
||||
process_splitwise_prefill<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
}
|
||||
|
||||
void DispatchTokenMode(
|
||||
void DispatchRunner(
|
||||
const cudaStream_t &stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -291,6 +261,7 @@ void DispatchTokenMode(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -310,75 +281,79 @@ void DispatchTokenMode(
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill) {
|
||||
if (truncate_first_token) {
|
||||
DispatchRunner<true>(
|
||||
stream,
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
} else {
|
||||
DispatchRunner<false>(
|
||||
stream,
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
}
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
DISPATCH_BLOCKSIZE(512, {
|
||||
DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
|
||||
DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
|
||||
DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
|
||||
if constexpr (SPLITWISE_PREFILL) {
|
||||
process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
@@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
@@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int num_model_step,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill) {
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
int accept_tokens_len = accept_tokens.shape()[1];
|
||||
int input_ids_len = input_ids.shape()[1];
|
||||
@@ -412,36 +389,38 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
auto not_need_stop_gpu =
|
||||
not_need_stop.copy_to(seq_lens_this_time.place(), false);
|
||||
|
||||
DispatchTokenMode(
|
||||
cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill);
|
||||
DispatchRunner(
|
||||
cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(is_block_step.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
@@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"seq_lens_decoder",
|
||||
"step_idx",
|
||||
"not_need_stop",
|
||||
"is_block_step",
|
||||
"batch_drop",
|
||||
"pre_ids",
|
||||
"accept_tokens",
|
||||
@@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"not_need_stop_out",
|
||||
"batch_drop_out",
|
||||
"pre_ids_out"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
|
@@ -63,7 +63,7 @@ __global__ void ComputeOrderKernel(
|
||||
position_map[in_offset++] = out_offset++;
|
||||
}
|
||||
in_offset += cur_base_model_seq_lens_this_time - accept_num;
|
||||
// (liuzichang): Temperary Reserved for debug
|
||||
// (liuzichang): Temporary Reserved for debug
|
||||
// if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
|
||||
// #ifdef DEBUG_EAGLE_KERNEL
|
||||
// printf("batch %d: accept_num <= actual_draft_token_num \n", i);
|
@@ -139,6 +139,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
|
||||
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
|
||||
.Inputs({"input_ids",
|
||||
"draft_tokens",
|
||||
"cum_offsets"
|
||||
"token_num",
|
||||
"seq_len",
|
||||
"seq_lens_encoder"})
|
||||
|
@@ -0,0 +1,176 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void speculate_schedula_cache(
|
||||
const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq) {
|
||||
const int bid = threadIdx.x;
|
||||
int stop_flag_now_int = 0;
|
||||
if (bid < real_bsz) {
|
||||
if (!stop_flags[bid]) {
|
||||
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
|
||||
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
|
||||
int *block_table_now = block_tables + bid * block_num_per_seq;
|
||||
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
|
||||
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
|
||||
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
|
||||
is_block_step[bid] = true;
|
||||
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
|
||||
seq_lens_this_time[bid] = 0;
|
||||
stop_flags[bid] = true;
|
||||
stop_flag_now_int = 1;
|
||||
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
|
||||
seq_lens_decoder[bid] = 0;
|
||||
accept_num[bid] = 0;
|
||||
for (int i = 0; i < accept_tokens_len; i++) {
|
||||
accept_tokens_now[i] = -1;
|
||||
}
|
||||
for (int i = 0; i < draft_tokens_len; i++) {
|
||||
step_draft_tokens_now[i] = draft_tokens_now[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
} else if (bid >= real_bsz && bid < max_bsz) {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
__syncthreads();
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
}
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int accept_tokens_len = accept_tokens.shape()[1];
|
||||
const int draft_token_len = draft_tokens.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
|
||||
constexpr int BlockSize = 512;
|
||||
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
|
||||
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
draft_tokens.data<int64_t>(),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
|
||||
const_cast<int *>(step_seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_token_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq
|
||||
);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
.Inputs({"draft_tokens",
|
||||
"block_tables",
|
||||
"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"step_draft_tokens",
|
||||
"step_seq_lens_this_time",
|
||||
"accept_num",
|
||||
"accept_tokens",
|
||||
"is_block_step",
|
||||
"not_need_stop",
|
||||
"stop_nums"})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"draft_tokens_out",
|
||||
"block_tables_out",
|
||||
"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"seq_lens_decoder_out",
|
||||
"step_seq_lens_decoder_out",
|
||||
"step_draft_tokens_out",
|
||||
"step_seq_lens_this_time_out",
|
||||
"accept_num_out",
|
||||
"accept_tokens_out",
|
||||
"is_block_step_out",
|
||||
"not_need_stop_out"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"block_tables", "block_tables_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"step_draft_tokens", "step_draft_tokens_out"},
|
||||
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
|
||||
{"accept_num", "accept_num_out"},
|
||||
{"accept_tokens", "accept_tokens_out"},
|
||||
{"is_block_step", "is_block_step_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
|
@@ -35,7 +35,7 @@ __global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
|
||||
accept_tokens + tid * max_draft_tokens;
|
||||
const int seq_len_dec = seq_lens_decoder[tid];
|
||||
const int seq_len_enc = seq_lens_encoder[tid];
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
|
||||
// printf("step_idx[tid] %d\n", step_idx[tid]);
|
||||
if (step_idx[tid] >= 0) {
|
||||
for (int i = 0; i < accept_num[tid]; i++) {
|
@@ -295,7 +295,7 @@ void SpeculateStepSchedule(const paddle::Tensor &stop_flags,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -283,7 +283,7 @@ void Schedule(const paddle::Tensor &stop_flags,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -58,7 +58,7 @@ class TokenTransfer {
|
||||
}
|
||||
|
||||
// once copy: cpu --> cpu
|
||||
// arrary length should be (1 + MAX_BATCH)
|
||||
// array length should be (1 + MAX_BATCH)
|
||||
bool GetBatchToken(int64_t *array) {
|
||||
if (Empty()) {
|
||||
return false;
|
||||
|
@@ -1,71 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "cuda_multiprocess.h"
|
||||
|
||||
#if !defined(_WIN32)
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#endif
|
||||
|
||||
// 可选:仅删除/解除共享内存命名对象(不依赖之前保存的 addr/fd)
|
||||
static inline int sharedMemoryUnlinkByName(const char* name) {
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
|
||||
// Windows 上没有 shm_unlink 语义。命名对象在最后一个句柄关闭后消失。
|
||||
// 这里做“尽力而为”:尝试打开后立即关闭,减少一次引用。
|
||||
HANDLE hMap = OpenFileMappingA(FILE_MAP_ALL_ACCESS, FALSE, name);
|
||||
if (hMap) {
|
||||
CloseHandle(hMap);
|
||||
return 0;
|
||||
}
|
||||
// 已经不存在也算成功
|
||||
return 0;
|
||||
#else
|
||||
// POSIX: 移除名字,未来不可再 open;已映射区仍存活直至 munmap
|
||||
if (shm_unlink(name) != 0) {
|
||||
if (errno == ENOENT) return 0; // 不存在视作成功
|
||||
return errno;
|
||||
}
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void UnsetDataIpc(const paddle::Tensor& tmp_input,
|
||||
const std::string& shm_name,
|
||||
bool close_ipc,
|
||||
bool unlink_shm) {
|
||||
// 1) 关闭消费者导入的 IPC 映射(仅当 close_ipc=true 且该指针确为 OpenMemHandle 得来)
|
||||
if (close_ipc) {
|
||||
void* ptr = const_cast<void*>(tmp_input.data());
|
||||
checkCudaErrors(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
|
||||
// 2) 解除共享内存命名对象(仅处理“名字”,不保证解除旧映射)
|
||||
if (unlink_shm) {
|
||||
int rc = sharedMemoryUnlinkByName(shm_name.c_str());
|
||||
if (rc != 0) {
|
||||
PD_THROW("Unlink shared memory failed: name=%s, err=%d",
|
||||
shm_name.c_str(), rc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(unset_data_ipc)
|
||||
.Inputs({"tmp_input"})
|
||||
.Attrs({"shm_name: std::string", "close_ipc: bool", "unlink_shm: bool"})
|
||||
.SetKernelFn(PD_KERNEL(UnsetDataIpc));
|
@@ -75,10 +75,10 @@ void UpdateSplitFuseInputes(const paddle::Tensor& split_fuse_seq_lens,
|
||||
const int max_seq_len,
|
||||
const int max_batch_size,
|
||||
const int split_fuse_size) {
|
||||
dim3 girds;
|
||||
girds.x = max_batch_size;
|
||||
dim3 grids;
|
||||
grids.x = max_batch_size;
|
||||
const int block_size = 128;
|
||||
update_split_fuse_inputs_kernel<<<girds,
|
||||
update_split_fuse_inputs_kernel<<<grids,
|
||||
block_size,
|
||||
0,
|
||||
input_ids.stream()>>>(
|
||||
|
@@ -110,7 +110,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill sucess ?)
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
@@ -37,52 +37,6 @@ def load_module_from_path(module_name, path):
|
||||
return module
|
||||
|
||||
|
||||
def update_git_repo():
|
||||
try:
|
||||
print("update third party repo...", flush=True)
|
||||
original_dir = os.getcwd()
|
||||
submodule_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
third_party_path = os.path.join(submodule_dir, "third_party")
|
||||
root_path = Path(third_party_path)
|
||||
|
||||
# check if third_party is empty
|
||||
update_third_party = False
|
||||
for dirpath in root_path.iterdir():
|
||||
if dirpath.is_dir():
|
||||
has_content = any(dirpath.iterdir())
|
||||
if not has_content:
|
||||
update_third_party = True
|
||||
|
||||
if update_third_party:
|
||||
os.chdir(submodule_dir)
|
||||
subprocess.run(
|
||||
"git submodule sync --recursive && git submodule update --init --recursive",
|
||||
shell=True,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\033[33m[===WARNING===]third_party directory already exists, skip clone and update.\033[0m",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# apply deep gemm patch
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
dst_path = os.path.join(submodule_dir, deep_gemm_dir)
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
patch_source = os.path.join(submodule_dir, patch)
|
||||
patch_destination = os.path.join(dst_path, patch)
|
||||
if not os.path.exists(patch_destination):
|
||||
shutil.copy(patch_source, patch_destination)
|
||||
apply_cmd = ["git", "apply", patch]
|
||||
os.chdir(dst_path)
|
||||
subprocess.run(apply_cmd, check=True)
|
||||
os.chdir(original_dir)
|
||||
except subprocess.CalledProcessError:
|
||||
raise Exception("Git submodule update and apply patch failed. Maybe network connection is poor.")
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
|
||||
# cannot import envs directly because it depends on fastdeploy,
|
||||
@@ -92,8 +46,6 @@ envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.
|
||||
archs = json.loads(envs.FD_BUILDING_ARCS)
|
||||
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
|
||||
|
||||
update_git_repo()
|
||||
|
||||
|
||||
def download_and_extract(url, destination_directory):
|
||||
"""
|
||||
@@ -126,6 +78,52 @@ def download_and_extract(url, destination_directory):
|
||||
print(f"Error extracting file: {e}")
|
||||
|
||||
|
||||
def clone_git_repo(version, repo_url, destination_path):
|
||||
"""
|
||||
Clone git repo to destination path.
|
||||
"""
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"-b",
|
||||
version,
|
||||
"--single-branch",
|
||||
repo_url,
|
||||
destination_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def process_git_repo(cur_path, dst_path, commit_id=None, patch=None):
|
||||
"""
|
||||
reset git repo to destination commit and apply patch.
|
||||
"""
|
||||
if commit_id is not None:
|
||||
reset_cmd = ["git", "reset", "--hard", commit_id]
|
||||
if patch is not None:
|
||||
patch_source = os.path.join(cur_path, patch)
|
||||
patch_destination = os.path.join(dst_path, patch)
|
||||
shutil.copy(patch_source, patch_destination)
|
||||
apply_cmd = ["git", "apply", patch]
|
||||
|
||||
try:
|
||||
os.chdir(dst_path)
|
||||
if commit_id is not None:
|
||||
subprocess.run(reset_cmd, check=True)
|
||||
if patch is not None:
|
||||
subprocess.run(apply_cmd, check=True)
|
||||
os.chdir(cur_path)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def get_sm_version(archs):
|
||||
"""
|
||||
Get sm version of paddle.
|
||||
@@ -193,13 +191,20 @@ def find_end_files(directory, end_str):
|
||||
if paddle.is_compiled_with_rocm():
|
||||
# NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
|
||||
# so we need to check if paddle compiled with rocm at first.
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
"gpu_ops/get_output.cc",
|
||||
"gpu_ops/get_output_msg_with_topk.cc",
|
||||
"gpu_ops/save_output_msg_with_topk.cc",
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/token_penalty_multi_scores.cu",
|
||||
"gpu_ops/stop_generation.cu",
|
||||
"gpu_ops/stop_generation_multi_ends.cu",
|
||||
@@ -208,7 +213,6 @@ if paddle.is_compiled_with_rocm():
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/step.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/unset_data_ipc.cu",
|
||||
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
||||
"gpu_ops/step_system_cache.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
@@ -219,7 +223,7 @@ if paddle.is_compiled_with_rocm():
|
||||
"gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_get_seq_lens_output.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_save_output.cc",
|
||||
"gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_step.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_update_v3.cu",
|
||||
@@ -257,7 +261,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/save_output_msg_with_topk.cc",
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"gpu_ops/set_mask_value.cu",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/ngram_mask.cu",
|
||||
"gpu_ops/gather_idx.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
@@ -272,14 +276,13 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/recover_decode_task.cu",
|
||||
"gpu_ops/step.cu",
|
||||
"gpu_ops/step_reschedule.cu",
|
||||
"gpu_ops/fused_get_rope.cu",
|
||||
"gpu_ops/fused_get_rotary_embedding.cu",
|
||||
"gpu_ops/get_padding_offset.cu",
|
||||
"gpu_ops/update_inputs.cu",
|
||||
"gpu_ops/update_inputs_beam.cu",
|
||||
"gpu_ops/beam_search_softmax.cu",
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/unset_data_ipc.cu",
|
||||
"gpu_ops/read_data_ipc.cu",
|
||||
"gpu_ops/enforce_generation.cu",
|
||||
"gpu_ops/dequant_int8.cu",
|
||||
@@ -313,6 +316,28 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
|
||||
]
|
||||
|
||||
cutlass_dir = "third_party/cutlass"
|
||||
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
|
||||
if not os.path.exists(cutlass_dir):
|
||||
os.makedirs(cutlass_dir)
|
||||
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
|
||||
if not os.listdir(cutlass_dir):
|
||||
raise ValueError("Git clone cutlass failed!")
|
||||
|
||||
# deep gemm
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
|
||||
if not os.path.exists(deep_gemm_dir):
|
||||
os.makedirs(deep_gemm_dir)
|
||||
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
|
||||
if not os.listdir(deep_gemm_dir):
|
||||
raise ValueError("Git clone DeepGEMM failed!")
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
dst_path = os.path.join(cur_path, deep_gemm_dir)
|
||||
commit_id = "95e81b3dd6704e279e5f4757c5b94776ac988a8d"
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
process_git_repo(cur_path, dst_path, commit_id, patch)
|
||||
|
||||
dg_third_party_include_dirs = (
|
||||
"third_party/cutlass/include/cute",
|
||||
"third_party/cutlass/include/cutlass",
|
||||
@@ -340,6 +365,14 @@ elif paddle.is_compiled_with_cuda():
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
|
||||
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
|
||||
cc_compile_args = []
|
||||
nvcc_compile_args = get_gencode_flags(archs)
|
||||
nvcc_compile_args += ["-DPADDLE_DEV"]
|
||||
@@ -474,7 +507,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
sources += find_end_files(fp8_auto_gen_directory, ".cu")
|
||||
|
||||
if cc >= 90 and nvcc_version >= 12.0:
|
||||
# Hopper optmized mla
|
||||
# Hopper optimized mla
|
||||
sources += find_end_files("gpu_ops/mla_attn", ".cu")
|
||||
sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"]
|
||||
sources += find_end_files("gpu_ops/moba_attn/moba_decoder_attn/", ".cu")
|
||||
@@ -527,7 +560,7 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
"gpu_ops/save_output_msg_with_topk.cc",
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"gpu_ops/get_padding_offset.cu",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/update_inputs.cu",
|
||||
"gpu_ops/stop_generation_multi_ends.cu",
|
||||
@@ -560,6 +593,13 @@ elif paddle.is_compiled_with_custom_device("gcu"):
|
||||
)
|
||||
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
maca_path = os.getenv("MACA_PATH", "/opt/maca")
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://gitee.com/learnlov/mirrors_nlohmann_json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/update_inputs_v1.cu",
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
@@ -569,7 +609,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"gpu_ops/save_with_output.cc",
|
||||
"gpu_ops/set_mask_value.cu",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/ngram_mask.cu",
|
||||
"gpu_ops/gather_idx.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
@@ -578,7 +618,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
"gpu_ops/stop_generation.cu",
|
||||
"gpu_ops/stop_generation_multi_ends.cu",
|
||||
"gpu_ops/set_flags.cu",
|
||||
"gpu_ops/fused_get_rope.cu",
|
||||
"gpu_ops/fused_get_rotary_embedding.cu",
|
||||
"gpu_ops/get_padding_offset.cu",
|
||||
"gpu_ops/update_inputs.cu",
|
||||
"gpu_ops/update_inputs_beam.cu",
|
||||
|
1
custom_ops/third_party/DeepGEMM
vendored
1
custom_ops/third_party/DeepGEMM
vendored
Submodule custom_ops/third_party/DeepGEMM deleted from 95e81b3dd6
1
custom_ops/third_party/cutlass
vendored
1
custom_ops/third_party/cutlass
vendored
Submodule custom_ops/third_party/cutlass deleted from afa1772203
1
custom_ops/third_party/nlohmann_json
vendored
1
custom_ops/third_party/nlohmann_json
vendored
Submodule custom_ops/third_party/nlohmann_json deleted from 9cca280a4d
@@ -67,7 +67,7 @@ std::vector<paddle::Tensor> MoeLayerKernel(
|
||||
const auto xtype = x.dtype();
|
||||
auto x_dims = x.shape();
|
||||
auto up_gate_proj_dims = up_gate_proj_weight.shape();
|
||||
PD_CHECK(x_dims.size() == 2, "x_dims.size() shoud be 2.");
|
||||
PD_CHECK(x_dims.size() == 2, "x_dims.size() should be 2.");
|
||||
PD_CHECK(up_gate_proj_dims.size() == 3, "up_gate_proj_dims.size() should be 3.");
|
||||
PD_CHECK(down_proj_in_scale.get_ptr() == nullptr, "down_proj_in_scale not support.");
|
||||
if (quant_method == "weight_only_int4") {
|
||||
|
@@ -122,7 +122,7 @@ void SpeculateStepSchedule(
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -59,7 +59,7 @@ void SaveOutMmsg(const paddle::Tensor &x, const paddle::Tensor ¬_need_stop,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -4,7 +4,7 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
#define MAX_LM_SIZE 28672
|
||||
// One core has 32KB LM(gropu LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// the stack space
|
||||
#define MAX_BATCH 512
|
||||
#define ALIGNMENT 64
|
||||
|
@@ -4,7 +4,7 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
#define MAX_LM_SIZE 28672
|
||||
// One core has 32KB LM(gropu LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// the stack space
|
||||
#define MAX_BATCH 512
|
||||
#define ALIGNMENT 64
|
||||
|
@@ -8,7 +8,7 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
#define MAX_SM_SIZE 32768
|
||||
// One core has 32KB LM(gropu LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// the stack space
|
||||
#define MAX_BATCH 512
|
||||
#define BANK_CONFLICT_M 128
|
||||
|
@@ -79,7 +79,7 @@ qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
|
||||
# print("wscale_pd:\n{}".format(wscale_pd))
|
||||
# print("wscale_np:\n{}".format(wscale_np))
|
||||
|
||||
# comparation
|
||||
# comparison
|
||||
print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}")
|
||||
print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}")
|
||||
print(f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}")
|
||||
|
96
docs/best_practices/ERNIE-4.5-21B-A3B-Thinking.md
Normal file
96
docs/best_practices/ERNIE-4.5-21B-A3B-Thinking.md
Normal file
@@ -0,0 +1,96 @@
|
||||
# ERNIE-4.5-21B-A3B
|
||||
## Environmental Preparation
|
||||
### 1.1 Hardware requirements
|
||||
The minimum number of GPUs required to deploy `ERNIE-4.5-21B-A3B` on the following hardware for each quantization is as follows:
|
||||
|
||||
| | WINT8 | WINT4 | FP8 |
|
||||
|-----|-----|-----|-----|
|
||||
|H800 80GB| 1 | 1 | 1 |
|
||||
|A800 80GB| 1 | 1 | / |
|
||||
|H20 96GB| 1 | 1 | 1 |
|
||||
|L20 48GB| 1 | 1 | 1 |
|
||||
|A30 40GB| 2 | 1 | / |
|
||||
|
||||
**Tips:**
|
||||
1. To modify the number of deployment GPUs, specify `--tensor-parallel-size 2` in starting command.
|
||||
2. For hardware not listed in the table, you can estimate whether it can be deployed based on the GPU memory.
|
||||
3. ERNIE-4.5-21B-A3B-Thinking requires FastDeploy version >= 2.2.
|
||||
|
||||
### 1.2 Install fastdeploy and prepare the model
|
||||
- Installation: For detail, please refer to [Fastdeploy Installation](../get_started/installation/README.md).
|
||||
|
||||
- Model Download,For detail, please refer to [Supported Models](../supported_models.md).
|
||||
|
||||
## 2.How to Use
|
||||
### 2.1 Basic: Launching the Service
|
||||
Start the service by following command:
|
||||
```bash
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-21B-A3B-Thinking \
|
||||
--load_choices "default_v1" \
|
||||
--tensor-parallel-size 1 \
|
||||
--max-model-len 131072 \
|
||||
--quantization wint8 \
|
||||
--reasoning-parser ernie_x1 \
|
||||
--tool-call-parser ernie_x1 \
|
||||
--max-num-seqs 32
|
||||
```
|
||||
- `--quantization`: Indicates the quantization strategy used by the model. Different quantization strategies will result in different performance and accuracy of the model. It could be one of `wint8` / `wint4` / `block_wise_fp8`(Hopper is needed).
|
||||
- `--max-model-len`: Indicates the maximum number of tokens supported by the currently deployed service. The larger the value, the longer the context length the model can support, but the more GPU memory is occupied, which may affect the concurrency.
|
||||
- `--load_choices`: Indicates the version of the loader. "default_v1" means enabling the v1 version of the loader, which has faster loading speed and less memory usage.
|
||||
- `--reasoning-parser`, `--tool-call-parser`: Indicates the corresponding reasoning content and tool call parser.
|
||||
|
||||
For more parameter meanings and default settings, see [FastDeploy Parameter Documentation](../parameters.md)。
|
||||
|
||||
### 2.2 Advanced: How to get better performance
|
||||
#### 2.2.1 Correctly set parameters that match the application scenario
|
||||
Evaluate average input length, average output length, and maximum context length
|
||||
- Set max-model-len according to the maximum context length. For example, if the average input length is 2000 and the output length is 80000, then it is recommended to set it to 131072
|
||||
|
||||
#### 2.2.2 Prefix Caching
|
||||
**Idea:** The core idea of Prefix Caching is to avoid repeated calculations by caching the intermediate calculation results of the input sequence (KV Cache), thereby speeding up the response speed of multiple requests with the same prefix. For details, refer to [prefix-cache](../features/prefix_caching.md)
|
||||
|
||||
**How to enable:**
|
||||
Since version 2.2 (including the develop branch), Prefix Caching has been enabled by default.
|
||||
|
||||
For versions 2.1 and earlier, you need to enable it manually by adding following lines to the startup parameters, where `--enable-prefix-caching` enables prefix caching, and `--swap-space` enables CPU cache in addition to GPU cache. The size is GB and should be adjusted according to the actual situation of the machine. The recommended value is `(total machine memory - model size) * 20%`. If the service fails to start because other programs are occupying memory, try reducing the `--swap-space` value.
|
||||
```
|
||||
--enable-prefix-caching
|
||||
--swap-space 50
|
||||
```
|
||||
|
||||
#### 2.2.3 Chunked Prefill
|
||||
**Idea:** This strategy is adopted to split the prefill stage request into small-scale sub-chunks, and execute them in batches mixed with the decode request. This can better balance the computation-intensive (Prefill) and memory-intensive (Decode) operations, optimize GPU resource utilization, reduce the computational workload and memory usage of a single Prefill, thereby reducing the peak memory usage and avoiding the problem of insufficient memory. For details, please refer to [Chunked Prefill](../features/chunked_prefill.md)
|
||||
|
||||
**How to enable:**
|
||||
Since version 2.2 (including the develop branch), Chunked Prefill has been enabled by default.
|
||||
|
||||
For versions 2.1 and earlier, you need to enable it manually by adding
|
||||
```
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
|
||||
#### 2.2.4 CUDAGraph
|
||||
**Idea:**
|
||||
CUDAGraph is a GPU computing acceleration technology provided by NVIDIA. It achieves efficient execution and optimization of GPU tasks by capturing CUDA operation sequences into a graph structure. The core idea of CUDAGraph is to encapsulate a series of GPU computing and memory operations into a re-executable graph, thereby reducing CPU-GPU communication overhead, reducing kernel startup latency, and improving overall computing performance.
|
||||
|
||||
**How to enable:**
|
||||
Add the following lines to the startup parameters
|
||||
```
|
||||
--use-cudagraph
|
||||
```
|
||||
Notes:
|
||||
- Usually, no additional parameters need to be set, but CUDAGraph will generate some additional memory overhead, which may need to be adjusted in some scenarios with limited memory. For detailed parameter adjustments, please refer to [GraphOptimizationBackend](../features/graph_optimization.md) for related configuration parameter descriptions
|
||||
|
||||
#### 2.2.5 Rejection Sampling
|
||||
**Idea:**
|
||||
Rejection sampling is to generate samples from a proposal distribution that is easy to sample, avoiding explicit sorting to increase the sampling speed, which has a significant improvement on small-sized models.
|
||||
|
||||
**How to enable:**
|
||||
Add the following environment variables before starting
|
||||
```
|
||||
export FD_SAMPLING_CLASS=rejection
|
||||
```
|
||||
|
||||
## FAQ
|
||||
If you encounter any problems during use, you can refer to [FAQ](./FAQ.md).
|
@@ -3,5 +3,6 @@
|
||||
- [ERNIE-4.5-0.3B-Paddle.md](ERNIE-4.5-0.3B-Paddle.md)
|
||||
- [ERNIE-4.5-21B-A3B-Paddle.md](ERNIE-4.5-21B-A3B-Paddle.md)
|
||||
- [ERNIE-4.5-300B-A47B-Paddle.md](ERNIE-4.5-300B-A47B-Paddle.md)
|
||||
- [ERNIE-4.5-21B-A3B-Thinking.md](ERNIE-4.5-21B-A3B-Thinking.md)
|
||||
- [ERNIE-4.5-VL-28B-A3B-Paddle](ERNIE-4.5-VL-28B-A3B-Paddle.md)
|
||||
- [ERNIE-4.5-VL-424B-A47B-Paddle](ERNIE-4.5-VL-424B-A47B-Paddle.md)
|
||||
|
@@ -196,7 +196,7 @@ We selected a subset (longbook_sum_eng) from InfiniteBench as the performance ev
|
||||
## Usage
|
||||
|
||||
```
|
||||
export FD_ATTENTION_BACKEND="PLAS_ATTN"
|
||||
export FD_ATTENTION_BACKEND="MOBA_ATTN"
|
||||
|
||||
python -m fastdeploy.entrypoints.openai.api_server
|
||||
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
|
||||
@@ -207,13 +207,13 @@ python -m fastdeploy.entrypoints.openai.api_server
|
||||
--max-num-batched-tokens 8192 \
|
||||
--max-model-len 131072 \
|
||||
--max-num-seqs 32 \
|
||||
--plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
|
||||
--moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
|
||||
```
|
||||
|
||||
**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
|
||||
**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
|
||||
|
||||
**Parameter Description:**
|
||||
|
||||
* Setting `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
|
||||
* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` indicates that the range of top-k is between 50 and 60 when the encoder is sparse.
|
||||
* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` indicates that the range of top-k is between 100 and 120 when the decoder is sparse.
|
||||
* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
|
||||
* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` indicates that the range of top-k is between 50 and 60 when the encoder is sparse.
|
||||
* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` indicates that the range of top-k is between 100 and 120 when the decoder is sparse.
|
||||
|
@@ -17,6 +17,7 @@
|
||||
|ERNIE-4.5-300B-A47B-Base|BF16/WINT4/WINT8|✅|✅|✅|⛔|✅|128K|
|
||||
|ERNIE-4.5-VL-424B-A47B|BF16/WINT4/WINT8|🚧|✅|🚧|⛔|🚧|128K|
|
||||
|ERNIE-4.5-VL-28B-A3B|BF16/WINT4/WINT8|⛔|✅|🚧|⛔|🚧|128K|
|
||||
|ERNIE-4.5-21B-A3B-Thinking|BF16/WINT4/WINT8/FP8|⛔|✅|✅|✅|✅|128K|
|
||||
|ERNIE-4.5-21B-A3B|BF16/WINT4/WINT8/FP8|⛔|✅|✅|✅|✅|128K|
|
||||
|ERNIE-4.5-21B-A3B-Base|BF16/WINT4/WINT8/FP8|⛔|✅|✅|⛔|✅|128K|
|
||||
|ERNIE-4.5-0.3B|BF16/WINT8/FP8|⛔|✅|✅|⛔|✅|128K|
|
||||
@@ -33,11 +34,11 @@
|
||||
|
||||
## Supported Hardware
|
||||
|
||||
| Model | [NVIDIA GPU](./get_started/installation/nvidia_gpu.md) |[Kunlunxin XPU](./get_started/installation/kunlunxin_xpu.md) | Ascend NPU | [Hygon DCU](./get_started/installation/hygon_dcu.md) | [Iluvatar GPU](./get_started/installation/iluvatar_gpu.md) | [MetaX GPU](./get_started/installation/metax_gpu.md.md) | [Enflame GCU](./get_started/installation/Enflame_gcu.md) |
|
||||
| Model | [NVIDIA GPU](./get_started/installation/nvidia_gpu.md) |[Kunlunxin XPU](./get_started/installation/kunlunxin_xpu.md) | Ascend NPU | [Hygon DCU](./get_started/installation/hygon_dcu.md) | [Iluvatar GPU](./get_started/installation/iluvatar_gpu.md) | [MetaX GPU](./get_started/installation/metax_gpu.md) | [Enflame GCU](./get_started/installation/Enflame_gcu.md) |
|
||||
|:------|---------|------------|----------|-------------|-----------|-------------|-------------|
|
||||
| ERNIE4.5-VL-424B-A47B | ✅ | 🚧 | 🚧 | ⛔ | ⛔ | ⛔ | ⛔ |
|
||||
| ERNIE4.5-300B-A47B | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | ✅ |
|
||||
| ERNIE4.5-VL-28B-A3B | ✅ | 🚧 | 🚧 | ⛔ | 🚧 | ⛔ | ⛔ |
|
||||
| ERNIE4.5-300B-A47B | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | ✅ |
|
||||
| ERNIE4.5-VL-28B-A3B | ✅ | 🚧 | 🚧 | ⛔ | 🚧 | 🚧 | ⛔ |
|
||||
| ERNIE4.5-21B-A3B | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | ✅ |
|
||||
| ERNIE4.5-0.3B | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | ✅ |
|
||||
|
||||
|
@@ -33,7 +33,7 @@ These models accept text input.
|
||||
|
||||
|Models|DataType|Example HF Model|
|
||||
|-|-|-|
|
||||
|⭐ERNIE|BF16\WINT4\WINT8\W4A8C8\WINT2\FP8|baidu/ERNIE-4.5-VL-424B-A47B-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-Paddle<br> [quick start](./get_started/ernie-4.5.md)   [best practice](./best_practices/ERNIE-4.5-300B-A47B-Paddle.md);<br>baidu/ERNIE-4.5-300B-A47B-2Bits-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-W4A8C8-TP4-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-FP8-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-Base-Paddle;<br>[baidu/ERNIE-4.5-21B-A3B-Paddle](./best_practices/ERNIE-4.5-21B-A3B-Paddle.md);<br>baidu/ERNIE-4.5-21B-A3B-Base-Paddle;<br>baidu/ERNIE-4.5-0.3B-Paddle<br> [quick start](./get_started/quick_start.md)   [best practice](./best_practices/ERNIE-4.5-0.3B-Paddle.md);<br>baidu/ERNIE-4.5-0.3B-Base-Paddle, etc.|
|
||||
|⭐ERNIE|BF16\WINT4\WINT8\W4A8C8\WINT2\FP8|baidu/ERNIE-4.5-VL-424B-A47B-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-Paddle<br> [quick start](./get_started/ernie-4.5.md)   [best practice](./best_practices/ERNIE-4.5-300B-A47B-Paddle.md);<br>baidu/ERNIE-4.5-300B-A47B-2Bits-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-W4A8C8-TP4-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-FP8-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-Base-Paddle;<br>[baidu/ERNIE-4.5-21B-A3B-Paddle](./best_practices/ERNIE-4.5-21B-A3B-Paddle.md);<br>baidu/ERNIE-4.5-21B-A3B-Base-Paddle;<br>baidu/ERNIE-4.5-21B-A3B-Thinking;<br>baidu/ERNIE-4.5-0.3B-Paddle<br> [quick start](./get_started/quick_start.md)   [best practice](./best_practices/ERNIE-4.5-0.3B-Paddle.md);<br>baidu/ERNIE-4.5-0.3B-Base-Paddle, etc.|
|
||||
|⭐QWEN3-MOE|BF16/WINT4/WINT8/FP8|Qwen/Qwen3-235B-A22B;<br>Qwen/Qwen3-30B-A3B, etc.|
|
||||
|⭐QWEN3|BF16/WINT8/FP8|Qwen/qwen3-32B;<br>Qwen/qwen3-14B;<br>Qwen/qwen3-8B;<br>Qwen/qwen3-4B;<br>Qwen/qwen3-1.7B;<br>[Qwen/qwen3-0.6B](./get_started/quick_start_qwen.md), etc.|
|
||||
|⭐QWEN2.5|BF16/WINT8/FP8|Qwen/qwen2.5-72B;<br>Qwen/qwen2.5-32B;<br>Qwen/qwen2.5-14B;<br>Qwen/qwen2.5-7B;<br>Qwen/qwen2.5-3B;<br>Qwen/qwen2.5-1.5B;<br>Qwen/qwen2.5-0.5B, etc.|
|
||||
|
97
docs/zh/best_practices/ERNIE-4.5-21B-A3B-Thinking.md
Normal file
97
docs/zh/best_practices/ERNIE-4.5-21B-A3B-Thinking.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# ERNIE-4.5-21B-A3B-Thinking
|
||||
## 一、环境准备
|
||||
### 1.1 支持情况
|
||||
ERNIE-4.5-21B-A3B 各量化精度,在下列硬件上部署所需要的最小卡数如下:
|
||||
|
||||
| | WINT8 | WINT4 | FP8 |
|
||||
|-----|-----|-----|-----|
|
||||
|H800 80GB| 1 | 1 | 1 |
|
||||
|A800 80GB| 1 | 1 | / |
|
||||
|H20 96GB| 1 | 1 | 1 |
|
||||
|L20 48GB| 1 | 1 | 1 |
|
||||
|A30 40GB| 2 | 1 | / |
|
||||
|
||||
**注:**
|
||||
1. 在启动命令后指定`--tensor-parallel-size 2` 即可修改部署卡数
|
||||
2. 表格中未列出的硬件,可根据显存大小进行预估是否可以部署
|
||||
3. ERNIE-4.5-21B-A3B-Thinking 需要**FastDeploy 2.2**及以上版本支持
|
||||
|
||||
### 1.2 安装fastdeploy
|
||||
- 安装,请参考[Fastdeploy Installation](../get_started/installation/README.md)完成安装。
|
||||
|
||||
- 模型下载,请参考[支持模型列表](../supported_models.md)。
|
||||
|
||||
## 二、如何使用
|
||||
### 2.1 基础:启动服务
|
||||
通过下列命令启动服务
|
||||
```bash
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-21B-A3B-Thinking \
|
||||
--load_choices "default_v1" \
|
||||
--tensor-parallel-size 1 \
|
||||
--max-model-len 131072 \
|
||||
--quantization wint8 \
|
||||
--reasoning-parser ernie_x1 \
|
||||
--tool-call-parser ernie_x1 \
|
||||
--max-num-seqs 32
|
||||
```
|
||||
其中:
|
||||
- `--quantization`: 表示模型采用的量化策略。不同量化策略,模型的性能和精度也会不同。可选值包括:`wint8` / `wint4` / `block_wise_fp8`(需要Hopper架构)。
|
||||
- `--max-model-len`:表示当前部署的服务所支持的最长Token数量。设置得越大,模型可支持的上下文长度也越大,但相应占用的显存也越多,可能影响并发数。
|
||||
- `--load_choices`: 表示loader的版本,"default_v1"表示启用v1版本的loader,具有更快的加载速度和更少的内存使用。
|
||||
- `--reasoning-parser` 、 `--tool-call-parser`: 表示对应调用的思考内容和工具调用解析器
|
||||
|
||||
更多的参数含义与默认设置,请参见[FastDeploy参数说明](../parameters.md)。
|
||||
|
||||
### 2.2 进阶:如何获取更优性能
|
||||
#### 2.2.1 评估应用场景,正确设置参数
|
||||
结合应用场景,评估平均输入长度、平均输出长度、最大上下文长度。例如,平均输入长度为2000,输出长度为80000,那么建议设置为 32768
|
||||
- 根据最大上下文长度,设置`max-model-len`
|
||||
|
||||
#### 2.2.2 Prefix Caching
|
||||
**原理:** Prefix Caching的核心思想是通过缓存输入序列的中间计算结果(KV Cache),避免重复计算,从而加速具有相同前缀的多个请求的响应速度。具体参考[prefix-cache](../features/prefix_caching.md)
|
||||
|
||||
**启用方式:**
|
||||
自2.2版本开始(包括develop分支),Prefix Caching已经默认开启。
|
||||
|
||||
对于2.1及更早的版本,需要手动开启。其中`--enable-prefix-caching`表示启用前缀缓存,`--swap-space`表示在GPU缓存的基础上,额外开启CPU缓存,大小为GB,应根据机器实际情况调整。建议取值为`(机器总内存 - 模型大小) * 20%`。如果因为其他程序占用内存等原因导致服务启动失败,可以尝试减小`--swap-space`的值。
|
||||
```
|
||||
--enable-prefix-caching
|
||||
--swap-space 50
|
||||
```
|
||||
|
||||
#### 2.2.3 Chunked Prefill
|
||||
**原理:** 采用分块策略,将预填充(Prefill)阶段请求拆解为小规模子任务,与解码(Decode)请求混合批处理执行。可以更好地平衡计算密集型(Prefill)和访存密集型(Decode)操作,优化GPU资源利用率,减少单次Prefill的计算量和显存占用,从而降低显存峰值,避免显存不足的问题。 具体请参考[Chunked Prefill](../features/chunked_prefill.md)
|
||||
|
||||
**启用方式:**
|
||||
自2.2版本开始(包括develop分支),Chunked Prefill已经默认开启。
|
||||
|
||||
对于2.1及更早的版本,需要手动开启。
|
||||
```
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
|
||||
#### 2.2.4 CUDAGraph
|
||||
**原理:**
|
||||
CUDAGraph 是 NVIDIA 提供的一项 GPU 计算加速技术,通过将 CUDA 操作序列捕获(capture)为图结构(graph),实现 GPU 任务的高效执行和优化。CUDAGraph 的核心思想是将一系列 GPU 计算和内存操作封装为一个可重复执行的图,从而减少 CPU-GPU 通信开销、降低内核启动延迟,并提升整体计算性能。
|
||||
|
||||
**启用方式:**
|
||||
在启动命令中增加
|
||||
```
|
||||
--use-cudagraph
|
||||
```
|
||||
注:
|
||||
- 通常情况下不需要额外设置其他参数,但CUDAGraph会产生一些额外的显存开销,在一些显存受限的场景下可能需要调整。详细的参数调整请参考[GraphOptimizationBackend](../features/graph_optimization.md) 相关配置参数说明
|
||||
|
||||
#### 2.2.5 拒绝采样
|
||||
**原理:**
|
||||
拒绝采样即从一个易于采样的提议分布(proposal distribution)中生成样本,避免显式排序从而达到提升采样速度的效果,对小尺寸的模型有较明显的提升。
|
||||
|
||||
**启用方式:**
|
||||
启动前增加下列环境变量
|
||||
```
|
||||
export FD_SAMPLING_CLASS=rejection
|
||||
```
|
||||
|
||||
## 三、常见问题FAQ
|
||||
如果您在使用过程中遇到问题,可以在[FAQ](./FAQ.md)中查阅。
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user