mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-02 15:22:24 +08:00
Compare commits
120 Commits
release/2.
...
copilot/ad
Author | SHA1 | Date | |
---|---|---|---|
![]() |
017c82b993 | ||
![]() |
063ec680ff | ||
![]() |
98bfefea02 | ||
![]() |
c60adf4281 | ||
![]() |
bbd548ceb6 | ||
![]() |
f556561584 | ||
![]() |
a553d1896c | ||
![]() |
e31c8f7336 | ||
![]() |
de34222842 | ||
![]() |
8e8a5913da | ||
![]() |
9f0e2a6854 | ||
![]() |
30ddcc9115 | ||
![]() |
2359c8d21c | ||
![]() |
1dc1397ef6 | ||
![]() |
12326b60e1 | ||
![]() |
f12159b630 | ||
![]() |
08b3153661 | ||
![]() |
d00faeec69 | ||
![]() |
7e0bfd024f | ||
![]() |
1f056a7469 | ||
![]() |
319a4bf75f | ||
![]() |
f884cd4f62 | ||
![]() |
f32327661c | ||
![]() |
976aa88e66 | ||
![]() |
ed462cf238 | ||
![]() |
20495f927e | ||
![]() |
0c46318b34 | ||
![]() |
9ead10e1bc | ||
![]() |
571ddc677b | ||
![]() |
316ac546d3 | ||
![]() |
83bd55100b | ||
![]() |
aadd6a94d8 | ||
![]() |
2033450391 | ||
![]() |
ed5133f704 | ||
![]() |
17169a14f2 | ||
![]() |
3d0aaa5923 | ||
![]() |
472402bf4e | ||
![]() |
af49b81ffd | ||
![]() |
b5e20e3015 | ||
![]() |
7833f2f6cb | ||
![]() |
b649494655 | ||
![]() |
7c268693ed | ||
![]() |
e52ce1c4b1 | ||
![]() |
30a1c1783f | ||
![]() |
349aa6348b | ||
![]() |
0c45e225d3 | ||
![]() |
f6f726c773 | ||
![]() |
0d989829bb | ||
![]() |
bd7d15f7ea | ||
![]() |
2cf55168ca | ||
![]() |
41aee08982 | ||
![]() |
b23fc654d9 | ||
![]() |
ab1929f5ff | ||
![]() |
fc3bc56e59 | ||
![]() |
7643e6e6b2 | ||
![]() |
e0e7d68435 | ||
![]() |
4c160aa4dd | ||
![]() |
c7b7126b20 | ||
![]() |
29628de6a7 | ||
![]() |
ed97cf8396 | ||
![]() |
88d44a2c93 | ||
![]() |
f265a26f8b | ||
![]() |
f36a388ffe | ||
![]() |
22c165d6dd | ||
![]() |
e83251699f | ||
![]() |
ac46ef403a | ||
![]() |
0989788b29 | ||
![]() |
6ef3b611b0 | ||
![]() |
460809070c | ||
![]() |
7baf1b56e0 | ||
![]() |
9ec4fa0f8e | ||
![]() |
c870be6d27 | ||
![]() |
3790505319 | ||
![]() |
e24b745d48 | ||
![]() |
aaa2de1afa | ||
![]() |
abde903813 | ||
![]() |
7dbd9412b0 | ||
![]() |
fc598d4c5a | ||
![]() |
31313e0f3d | ||
![]() |
4c998c3636 | ||
![]() |
0a1ce612c2 | ||
![]() |
fa58a9fa8f | ||
![]() |
d22d3de256 | ||
![]() |
2527eb0e4e | ||
![]() |
54b458fd98 | ||
![]() |
d81c57146f | ||
![]() |
2396e49f9e | ||
![]() |
94a61d505c | ||
![]() |
ce998449e0 | ||
![]() |
f7a4bea785 | ||
![]() |
5441538173 | ||
![]() |
2c9b169c0e | ||
![]() |
e0c9a6c76c | ||
![]() |
0fe1d62232 | ||
![]() |
18e5d355a1 | ||
![]() |
8e1b35a09b | ||
![]() |
b6a4115369 | ||
![]() |
693c7d781c | ||
![]() |
aa067a3106 | ||
![]() |
7a521bbf62 | ||
![]() |
f296aff6cf | ||
![]() |
205b706ef8 | ||
![]() |
306c024ff3 | ||
![]() |
905d89e42f | ||
![]() |
1908465542 | ||
![]() |
0e4df5a6f4 | ||
![]() |
bf0cf5167a | ||
![]() |
7e751c93ae | ||
![]() |
27f2e7a6f1 | ||
![]() |
6ac7cea81b | ||
![]() |
adc246127b | ||
![]() |
6dd61a1bab | ||
![]() |
253f388372 | ||
![]() |
d6369b4d51 | ||
![]() |
0513a78ecc | ||
![]() |
0297127a93 | ||
![]() |
2bd7d90929 | ||
![]() |
6566e29807 | ||
![]() |
085fe070f2 | ||
![]() |
927e8ec55e |
4
.github/workflows/_accuracy_test.yml
vendored
4
.github/workflows/_accuracy_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -143,7 +143,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
|
4
.github/workflows/_base_test.yml
vendored
4
.github/workflows/_base_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -143,7 +143,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
|
3
.github/workflows/_build_linux.yml
vendored
3
.github/workflows/_build_linux.yml
vendored
@@ -134,7 +134,6 @@ jobs:
|
||||
fi
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
chown -R $(whoami) /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
|
||||
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
|
||||
@@ -149,7 +148,7 @@ jobs:
|
||||
elif [[ "${PADDLEVERSION}" != "" ]];then
|
||||
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
else
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
fi
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
73
.github/workflows/_ci_image_build.yml
vendored
Normal file
73
.github/workflows/_ci_image_build.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
name: Docker Build
|
||||
description: "FastDeploy CI Image Build"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
CI_DOCKER_IMAGE_NAME:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
DOCKER_IMAGE_NAME:
|
||||
description: "Build Images"
|
||||
required: false
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
||||
outputs:
|
||||
docker_name_precheck:
|
||||
description: "Output path of the generated wheel"
|
||||
value: ${{ jobs.docker_build.outputs.docker_name_precheck }}
|
||||
|
||||
jobs:
|
||||
docker_build:
|
||||
runs-on: [self-hosted, Docker-Build]
|
||||
outputs:
|
||||
docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
id: docker_build
|
||||
shell: bash
|
||||
env:
|
||||
docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }}
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE_NAME }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
# Docker Build
|
||||
cd tools/dockerfile/
|
||||
set -e
|
||||
cp ../../requirements.txt ./
|
||||
cp ../../scripts/unittest_requirement.txt ./
|
||||
docker build -t ${docker_image_name} -f Dockerfile.ci . \
|
||||
--network host \
|
||||
--no-cache
|
||||
docker push ${docker_image_name}
|
||||
echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT
|
4
.github/workflows/_logprob_test_linux.yml
vendored
4
.github/workflows/_logprob_test_linux.yml
vendored
@@ -39,6 +39,7 @@ jobs:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
|
||||
run: |
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -116,7 +117,6 @@ jobs:
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
@@ -133,7 +133,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
|
4
.github/workflows/_pre_ce_test.yml
vendored
4
.github/workflows/_pre_ce_test.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
python -m pip install ${fd_wheel_url}
|
||||
bash scripts/run_pre_ce.sh
|
||||
'
|
||||
|
4
.github/workflows/_stable_test.yml
vendored
4
.github/workflows/_stable_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
|
9
.github/workflows/_unit_test_coverage.yml
vendored
9
.github/workflows/_unit_test_coverage.yml
vendored
@@ -60,7 +60,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -168,13 +168,10 @@ jobs:
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install coverage
|
||||
python -m pip install diff-cover
|
||||
python -m pip install pytest-cov
|
||||
python -m pip install jsonschema aistudio_sdk==0.3.5
|
||||
python -m pip install -r scripts/unittest_requirement.txt
|
||||
python -m pip install ${fd_wheel_url}
|
||||
rm -rf fastdeploy
|
||||
# coverage subprocess use
|
||||
|
18
.github/workflows/ce_job.yml
vendored
18
.github/workflows/ce_job.yml
vendored
@@ -199,13 +199,13 @@ jobs:
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
|
||||
ce_upload_sm8689:
|
||||
@@ -224,9 +224,9 @@ jobs:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
run: |
|
||||
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
|
||||
echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
|
||||
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
@@ -238,11 +238,11 @@ jobs:
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
|
174
.github/workflows/ci_image_update.yml
vendored
Normal file
174
.github/workflows/ci_image_update.yml
vendored
Normal file
@@ -0,0 +1,174 @@
|
||||
name: CI Images Build
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
|
||||
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
jobs:
|
||||
clone:
|
||||
environment: CodeSync
|
||||
name: FD-Clone-Linux
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Python Setup
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
commit_id=${{ github.sha }}
|
||||
tag_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
|
||||
else
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
||||
|
||||
resultshow:
|
||||
name: Show Code Archive Output
|
||||
needs: clone
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Print wheel path
|
||||
run: |
|
||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
||||
|
||||
ci_image_build:
|
||||
name: CI Images Build
|
||||
needs: clone
|
||||
uses: ./.github/workflows/_ci_image_build.yml
|
||||
with:
|
||||
CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
|
||||
|
||||
build_sm8090:
|
||||
name: BUILD_SM8090
|
||||
needs: [clone, ci_image_build]
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "90"
|
||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
|
||||
unittest_coverage:
|
||||
name: Run FastDeploy Unit Tests and Coverage
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
logprob_test:
|
||||
name: Run FastDeploy LogProb Tests
|
||||
needs: [build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
pre_ce_test:
|
||||
name: Extracted partial CE model tasks to run in CI.
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_pre_ce_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
base_test:
|
||||
name: Run Base Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_base_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
accuracy_test:
|
||||
name: Run Accuracy Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_accuracy_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
|
||||
publish_pre_check:
|
||||
name: Publish Docker Images Pre Check
|
||||
needs: [ci_image_build, unittest_coverage,logprob_test,pre_ce_test,base_test,accuracy_test,stable_test]
|
||||
runs-on: [self-hosted, Docker-Build]
|
||||
steps:
|
||||
- name: Images Uploading
|
||||
env:
|
||||
images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
||||
run: |
|
||||
echo "images_name=${images_name}"
|
||||
docker images ${ci_image_name}
|
||||
docker tag ${images_name} ${ci_image_name}
|
||||
docker push ${ci_image_name}
|
2
.github/workflows/pr_build_and_test.yml
vendored
2
.github/workflows/pr_build_and_test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "89,90"
|
||||
COMPILE_ARCH: "90"
|
||||
WITH_NIGHTLY_BUILD: "OFF"
|
||||
FD_VERSION: "0.0.0"
|
||||
|
||||
|
10
.github/workflows/publish_job.yml
vendored
10
.github/workflows/publish_job.yml
vendored
@@ -319,3 +319,13 @@ jobs:
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
9
.gitmodules
vendored
9
.gitmodules
vendored
@@ -1,9 +0,0 @@
|
||||
[submodule "custom_ops/third_party/DeepGEMM"]
|
||||
path = custom_ops/third_party/DeepGEMM
|
||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||
[submodule "custom_ops/third_party/cutlass"]
|
||||
path = custom_ops/third_party/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
[submodule "custom_ops/third_party/nlohmann_json"]
|
||||
path = custom_ops/third_party/nlohmann_json
|
||||
url = https://github.com/nlohmann/json.git
|
@@ -59,7 +59,7 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
|
||||
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
|
||||
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
|
||||
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
|
||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md.md)
|
||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
|
||||
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
|
||||
|
||||
|
@@ -57,7 +57,7 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
|
||||
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
|
||||
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
|
||||
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
|
||||
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md.md)
|
||||
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md)
|
||||
|
||||
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!
|
||||
|
||||
|
@@ -14,7 +14,7 @@
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
void set_value_by_flag_and_id(const bool *stop_flags,
|
||||
void set_value_by_flags_and_idx(const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
@@ -50,7 +50,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
int length = pre_ids_all_shape[1];
|
||||
int length_input_ids = input_ids.shape()[1];
|
||||
|
||||
set_value_by_flag_and_id(stop_flags.data<bool>(),
|
||||
set_value_by_flags_and_idx(stop_flags.data<bool>(),
|
||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||
input_ids.data<int64_t>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
|
@@ -46,7 +46,7 @@ void update_inputs_kernel(bool *not_need_stop,
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
|
||||
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
void UpdateInputs(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu)
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"input_ids", "input_ids_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputes));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputs));
|
||||
|
@@ -140,8 +140,8 @@ void AppendAttentionKernel(
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_mask,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
|
||||
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
@@ -273,11 +273,15 @@ void AppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
@@ -296,11 +300,15 @@ void AppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -309,7 +317,6 @@ void AppendAttentionKernel(
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
@@ -336,7 +343,6 @@ void AppendAttentionKernel(
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
|
@@ -52,6 +52,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -74,6 +75,11 @@ __global__ void multi_query_append_attention_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -142,7 +148,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -422,6 +428,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -445,6 +452,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -511,7 +523,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -902,6 +914,7 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -960,6 +973,7 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1134,6 +1148,7 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1206,6 +1221,7 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
|
@@ -57,6 +57,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -85,6 +86,11 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -173,7 +179,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -520,6 +526,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -549,6 +556,11 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -635,7 +647,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -1107,6 +1119,7 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1171,6 +1184,7 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1365,6 +1379,7 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1445,6 +1460,7 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
|
@@ -32,14 +32,15 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8=false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__global__ void multi_query_append_attention_c8_kernel(
|
||||
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
|
||||
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
// head_dim]
|
||||
CacheT *__restrict__ cache_v,
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads]
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int *__restrict__ seq_lens,
|
||||
@@ -57,6 +58,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -86,33 +88,40 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
T cache_k_scale_reg[num_frags_y * 4];
|
||||
T cache_v_scale_reg[num_frags_y * 2];
|
||||
if (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
|
||||
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
|
||||
const uint32_t q_end =
|
||||
@@ -180,7 +189,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -201,6 +210,13 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale = nullptr;
|
||||
T* v_smem_scale = nullptr;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale = k_smem_scale + num_frags_z * 16;
|
||||
}
|
||||
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
@@ -282,10 +298,22 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale,
|
||||
cache_k_scale_reg,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
&qo_smem,
|
||||
&q_smem_offset_r,
|
||||
&k_smem,
|
||||
@@ -318,6 +346,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
s_frag, o_frag, m_frag, d_frag);
|
||||
__syncthreads();
|
||||
|
||||
const int ori_kv_idx_base = kv_idx_base;
|
||||
kv_idx_base += num_frags_z * 16;
|
||||
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -336,6 +365,18 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
commit_group();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale,
|
||||
cache_v_scale_reg,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
ori_kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
@@ -346,7 +387,9 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
BLOCK_SIZE,
|
||||
T,
|
||||
CacheT,
|
||||
is_scale_channel_wise, IsFP8>(
|
||||
is_scale_channel_wise,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
|
||||
__syncthreads();
|
||||
|
||||
@@ -463,14 +506,15 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool is_scale_channel_wise=false,
|
||||
bool IsFP8=false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
|
||||
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
// head_dim]
|
||||
CacheT *__restrict__ cache_v,
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int *__restrict__ seq_lens,
|
||||
@@ -489,6 +533,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -518,32 +563,39 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
T cache_k_scale_reg[num_frags_y * 4];
|
||||
T cache_v_scale_reg[num_frags_y * 2];
|
||||
if (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
|
||||
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
const uint32_t q_end =
|
||||
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
|
||||
@@ -609,7 +661,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -634,6 +686,13 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale = nullptr;
|
||||
T* v_smem_scale = nullptr;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
|
||||
}
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
CAUSAL
|
||||
@@ -716,11 +775,23 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
commit_group();
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale,
|
||||
cache_k_scale_reg,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
&qo_smem,
|
||||
&q_smem_offset_r,
|
||||
&k_smem,
|
||||
@@ -753,6 +824,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
s_frag, o_frag, m_frag, d_frag);
|
||||
__syncthreads();
|
||||
|
||||
const uint32_t ori_kv_idx_base = kv_idx_base;
|
||||
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
|
||||
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -771,6 +843,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
commit_group();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale,
|
||||
cache_v_scale_reg,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
ori_kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
@@ -781,7 +865,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
BLOCK_SIZE,
|
||||
T,
|
||||
CacheT,
|
||||
is_scale_channel_wise, IsFP8>(
|
||||
is_scale_channel_wise,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
|
||||
__syncthreads();
|
||||
|
||||
@@ -895,7 +981,8 @@ template <typename T,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool IsFP8=false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
void MultiQueryAppendC8Attention(
|
||||
const AppendAttnMetaData &meta_data,
|
||||
const paddle::Tensor &qkv,
|
||||
@@ -953,7 +1040,8 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
|
||||
constexpr uint32_t smem_size =
|
||||
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
|
||||
num_frags_z * 16 * sizeof(T) * 2;
|
||||
auto split_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -970,7 +1058,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false, IsFP8>;
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
if (is_scale_channel_wise) {
|
||||
split_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
@@ -988,7 +1078,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true, IsFP8>;
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(split_kv_kernel,
|
||||
@@ -1022,7 +1114,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false, IsFP8>;
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
if (is_scale_channel_wise) {
|
||||
nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
@@ -1040,7 +1134,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true, IsFP8>;
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(nosplit_kv_kernel,
|
||||
@@ -1075,6 +1171,7 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1133,6 +1230,7 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1218,7 +1316,8 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
|
||||
constexpr uint32_t smem_size =
|
||||
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
|
||||
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
|
||||
auto split_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1235,7 +1334,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false, IsFP8>;
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
if (is_scale_channel_wise) {
|
||||
split_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
@@ -1253,7 +1354,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true, IsFP8>;
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(split_kv_kernel,
|
||||
@@ -1295,7 +1398,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false, IsFP8>;
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
if (is_scale_channel_wise) {
|
||||
nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
@@ -1313,7 +1418,9 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true, IsFP8>;
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(nosplit_kv_kernel,
|
||||
@@ -1350,6 +1457,7 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1424,6 +1532,7 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1546,6 +1655,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
@@ -1554,6 +1664,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim = meta_data.head_dims;
|
||||
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";
|
||||
|
||||
DISPATCH_CAUSAL(
|
||||
causal,
|
||||
@@ -1572,43 +1683,46 @@ void CascadeAppendAttentionC8Kernel(
|
||||
BLOCK_SIZE,
|
||||
{DISPATCH_BLOCKSHAPE_Q(
|
||||
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
|
||||
MultiQueryAppendC8Attention<T,
|
||||
GROUP_SIZE,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
CAUSAL,
|
||||
BLOCK_SHAPE_Q,
|
||||
NUM_WARP_Q,
|
||||
OutT,
|
||||
ENABLE_PREFILL, IsFP8>(
|
||||
meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_mask,
|
||||
cache_k_scale.get(),
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
})})})})})})
|
||||
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
|
||||
MultiQueryAppendC8Attention<T,
|
||||
GROUP_SIZE,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
CAUSAL,
|
||||
BLOCK_SHAPE_Q,
|
||||
NUM_WARP_Q,
|
||||
OutT,
|
||||
ENABLE_PREFILL,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_mask,
|
||||
cache_k_scale.get(),
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
})})})})})})})
|
||||
}
|
||||
|
@@ -384,6 +384,113 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_k_dynamic_scale(
|
||||
T* k_smem_scale,
|
||||
T* cache_k_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_k_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
if (tid < block_size) {
|
||||
k_smem_scale[tid] = cache_k_scale_now[tid];
|
||||
}
|
||||
__syncthreads();
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
||||
if (kv_idx_this_thread < chunk_end) {
|
||||
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
|
||||
} else {
|
||||
k_smem_scale[ty * 32 + tx] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_v_dynamic_scale(
|
||||
T* v_smem_scale,
|
||||
T* cache_v_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_v_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
if (tid < block_size) {
|
||||
v_smem_scale[tid] = cache_v_scale_now[tid];
|
||||
}
|
||||
__syncthreads();
|
||||
const uint32_t row_id = tx % 4 * 2;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
||||
if (kv_idx_this_thread < chunk_end) {
|
||||
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
|
||||
} else {
|
||||
v_smem_scale[ty * 32 + tx] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
const uint32_t row_id = tx % 4 * 2;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <SharedMemFillMode fill_mode,
|
||||
uint32_t num_warps,
|
||||
uint32_t block_size,
|
||||
@@ -816,7 +923,8 @@ template <uint32_t num_frags_x,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8=false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||
uint32_t* q_smem_offset_r,
|
||||
smem_t* k_smem,
|
||||
@@ -860,20 +968,27 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
const int scale_col = (ky * 2 + fy) * 4;
|
||||
b_frag_dq_T[0] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
const int scale_col = (ky * 2 + fy) * 4;
|
||||
b_frag_dq_T[0] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[0];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[0];
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
@@ -929,7 +1044,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
8 * (reg_id / 4) + reg_id % 2;
|
||||
bool out_of_boundary;
|
||||
if (mask_offset) {
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
|
||||
} else {
|
||||
out_of_boundary =
|
||||
(causal
|
||||
@@ -1093,7 +1208,9 @@ template <uint32_t num_frags_x,
|
||||
uint32_t block_size,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__device__ __forceinline__ void compute_sfm_v_c8(
|
||||
smem_t* v_smem,
|
||||
uint32_t* v_smem_offset_r,
|
||||
@@ -1135,16 +1252,28 @@ __device__ __forceinline__ void compute_sfm_v_c8(
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
const int scale_col = (kz * 2 + fz) * 4;
|
||||
b_frag_dq_T[0] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
||||
@@ -1171,7 +1300,9 @@ template <uint32_t num_frags_x,
|
||||
uint32_t block_size,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
smem_t* v_smem,
|
||||
uint32_t* v_smem_offset_r,
|
||||
@@ -1215,16 +1346,28 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
const int scale_col = (kz * 2 + fz) * 4;
|
||||
b_frag_dq_T[0] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
||||
|
@@ -103,6 +103,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -264,9 +265,10 @@ void CascadeAppendAttentionKernel(
|
||||
causal,
|
||||
is_decoder,
|
||||
enable_prefill,
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
out);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
||||
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
@@ -299,6 +301,7 @@ void CascadeAppendAttentionKernel(
|
||||
causal,
|
||||
is_decoder,
|
||||
enable_prefill,
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
out);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
|
@@ -18,6 +18,53 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
|
||||
// Note(ZKK)
|
||||
// This function is very easy!
|
||||
// just make HeadDim data to be new HeadDim data!
|
||||
|
||||
template <typename T, int VecSize=8, int HEAD_DIM=128, int NUM_THREADS=32>
|
||||
__device__ __forceinline__ void apply_rope(
|
||||
const T* input,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
T* output,
|
||||
const int thread_id) {
|
||||
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM; head_bias += NUM_THREADS * VecSize) {
|
||||
Load<T, VecSize>(&input[head_bias], &src_vec);
|
||||
const uint32_t emb_idx = head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &output[head_bias]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -28,7 +75,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -120,7 +167,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
|
||||
if (hi < num_heads) { // q
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
@@ -129,6 +175,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
}
|
||||
} else { // k
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
@@ -164,7 +211,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -270,7 +317,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -381,142 +428,6 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadKVT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
|
||||
LoadT left_vec, right_vec;
|
||||
LoadBiasT left_bias_vec, right_bias_vec;
|
||||
LoadKVT left_cache_vec, right_cache_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int half_head_size = head_size / 2;
|
||||
const int half_rotary_dim = rotary_dim / 2;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||
const int64_t half_hidden_size = hidden_size / 2;
|
||||
// const int64_t offset = 2 * hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / half_hidden_size;
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
if (hi < num_heads && h_bias >= half_rotary_dim){
|
||||
continue;
|
||||
}
|
||||
if (seq_lens_encoder[ori_bi] > 0) continue;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
uint32_t ori_idx_left =
|
||||
start_token_idx * hidden_size + hi * head_size + h_bias;
|
||||
uint32_t ori_idx_right = ori_idx_left + half_head_size;
|
||||
if (hi < num_heads){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else if (hi < num_heads + kv_num_heads){
|
||||
if (h_bias < half_rotary_dim){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
ori_idx_left = ori_idx_left + half_rotary_dim;
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}
|
||||
}
|
||||
|
||||
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);
|
||||
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
if (h_bias < half_rotary_dim){
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
// rope
|
||||
float input_left = static_cast<float>(left_vec[i]);
|
||||
float input_right = static_cast<float>(right_vec[i]);
|
||||
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
|
||||
} else {
|
||||
// write k/v
|
||||
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
|
||||
uint32_t tgt_idx_left =
|
||||
block_idx * kv_num_heads * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size + block_offset * head_size +
|
||||
h_bias;
|
||||
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
if (h_bias < half_rotary_dim) {
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
tgt_idx_left = tgt_idx_left + half_rotary_dim;
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
|
||||
} else {
|
||||
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -527,7 +438,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -641,7 +551,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -765,6 +674,293 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
|
||||
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
T* __restrict__ cache_k_scale,
|
||||
T* __restrict__ cache_v_scale,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d,
|
||||
const float rms_norm_eps) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane_id = tid % 32;
|
||||
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + bid * max_blocks_per_seq;
|
||||
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
int cache_offset;
|
||||
if (head_idx < num_heads) {
|
||||
cache_offset = 0;
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
|
||||
}
|
||||
T *cache_k_scale_now = cache_k_scale + cache_offset;
|
||||
T *cache_v_scale_now = cache_v_scale + cache_offset;
|
||||
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(tmp2);
|
||||
}
|
||||
// qk norm
|
||||
if (q_norm_weight) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadOutScaleT q_norm_vec;
|
||||
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
|
||||
if (block_offset == 0) {
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||
const uint32_t tgt_idx =
|
||||
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
|
||||
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||
for (int block_i = lane_id / num_vecs_per_head_dim;
|
||||
block_i < block_size;
|
||||
block_i += num_token_each_time) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&key_cache[tgt_idx + block_i * HeadDim]);
|
||||
}
|
||||
} else {
|
||||
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
|
||||
const int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||
const uint32_t tgt_idx =
|
||||
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
|
||||
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
|
||||
block_i += num_token_each_time) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
constexpr int K_VEC_SIZE = 4;
|
||||
constexpr int HALF_K_VEC_SIZE = 2;
|
||||
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
|
||||
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
|
||||
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
||||
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
||||
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
|
||||
using LoadEmbT = AlignedVector<float, 1>;
|
||||
LoadKVResT cache_vec;
|
||||
LoadT src_vec1, src_vec2;
|
||||
LoadBiasT out_vec1, out_vec2;
|
||||
LoadEmbT cos_emb_vec1, cos_emb_vec2;
|
||||
LoadEmbT sin_emb_vec1, sin_emb_vec2;
|
||||
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
||||
T scale = T(1.0f);
|
||||
const int k_head_idx = head_idx - num_heads;
|
||||
const int v_head_idx = head_idx - num_heads - kv_num_heads;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
}
|
||||
|
||||
float input_left = static_cast<float>(src_vec1[0]);
|
||||
float input_right = static_cast<float>(src_vec1[1]);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec1[0] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec1[1] =
|
||||
static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec1[0] = src_vec1[0];
|
||||
out_vec1[1] = src_vec1[1];
|
||||
}
|
||||
|
||||
// rope
|
||||
input_left = static_cast<float>(src_vec2[0]);
|
||||
input_right = static_cast<float>(src_vec2[1]);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec2[0] = static_cast<T>(tmp1);
|
||||
out_vec2[1] = static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec2[0] = src_vec2[0];
|
||||
out_vec2[1] = src_vec2[1];
|
||||
}
|
||||
if (k_norm_weight) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
LoadOutScaleT k_norm_vec1, k_norm_vec2;
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
|
||||
// qk norm
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
|
||||
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// reduce max, 1 head per warp
|
||||
T local_max = -INFINITY;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
local_max = __hmax(local_max, __habs(out_vec1[i]));
|
||||
local_max = __hmax(local_max, __habs(out_vec2[i]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int m_offset = 16; m_offset > 1; m_offset /= 2) {
|
||||
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
||||
}
|
||||
|
||||
scale = __hdiv(448, local_max);
|
||||
|
||||
if (lane_id == 0) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
cache_k_scale_now[0] = __hdiv(1, scale);
|
||||
} else {
|
||||
cache_v_scale_now[0] = __hdiv(1, scale);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
|
||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
|
||||
}
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const int start_block_16 =
|
||||
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
|
||||
const uint32_t tgt_cache_idx =
|
||||
block_idx * kv_num_heads * block_size * HeadDim +
|
||||
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
|
||||
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
|
||||
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
|
||||
} else {
|
||||
const uint32_t base_tgt_cache_idx =
|
||||
block_idx * kv_num_heads * HeadDim * block_size +
|
||||
kv_head_idx * HeadDim * block_size +
|
||||
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
|
||||
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
|
||||
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
|
||||
block_offset % 8 / 2 * 4 // per 4
|
||||
+ block_offset % 16 / 8 * 2 // per 2
|
||||
+ block_offset % 2; // per 1
|
||||
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
|
||||
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
|
||||
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
|
||||
value_cache[tgt_cache_idx1] = cache_vec[0];
|
||||
value_cache[tgt_cache_idx2] = cache_vec[1];
|
||||
value_cache[tgt_cache_idx3] = cache_vec[2];
|
||||
value_cache[tgt_cache_idx4] = cache_vec[3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
|
||||
__global__ void append_decode_cache_int8_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -775,7 +971,6 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -813,44 +1008,18 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(
|
||||
qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
@@ -1025,7 +1194,6 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1330,7 +1498,6 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1632,7 +1799,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2029,7 +2196,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2070,44 +2237,18 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(
|
||||
qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
@@ -2327,7 +2468,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2658,7 +2799,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -3031,7 +3172,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
|
@@ -21,7 +21,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -59,7 +58,6 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -84,7 +82,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -97,7 +94,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const cudaStream_t& stream,
|
||||
@@ -121,7 +117,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -138,8 +133,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (rotary_dim < dim_head){
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -154,33 +148,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}else{
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -191,7 +162,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -214,7 +184,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -238,7 +207,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -271,7 +239,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -297,7 +264,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -323,7 +289,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -349,7 +314,6 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -375,7 +339,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -410,7 +373,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -438,7 +400,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -466,7 +427,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -494,7 +454,6 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -521,7 +480,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -558,20 +516,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const float* cos_emb =
|
||||
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
|
||||
const float* sin_emb;
|
||||
int rotary_dim = dim_head;
|
||||
if (rotary_embs) {
|
||||
sin_emb =
|
||||
use_neox_rotary_style
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < dim_head){
|
||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
|
||||
}
|
||||
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
@@ -582,7 +531,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -605,9 +553,39 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
|
||||
}
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
@@ -617,7 +595,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -632,7 +609,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
@@ -650,7 +626,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -683,7 +658,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -717,7 +691,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -743,6 +716,36 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
nullptr,
|
||||
nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
@@ -750,7 +753,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -798,7 +800,6 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -828,7 +829,6 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -857,7 +857,6 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -886,7 +885,6 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
|
@@ -23,7 +23,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
|
@@ -449,8 +449,8 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
const int half_lastdim = last_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * last_dim;
|
||||
const int all_head_num = elem_cnt / last_dim;
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
|
||||
int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_num; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize;
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
@@ -900,74 +900,6 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const float *qkv_out_scales,
|
||||
const T *qkv_biases,
|
||||
T *qkv_out,
|
||||
const int64_t elem_cnt,
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int head_dim,
|
||||
const int rotary_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT left_vec;
|
||||
LoadT right_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int rotary_dim_half = rotary_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
|
||||
for (int64_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens && seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / rotary_dim_half;
|
||||
const int h_bias = bias % rotary_dim_half;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
|
||||
const int base_idx_left =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
|
||||
h_bias;
|
||||
const int base_idx_right = base_idx_left + rotary_dim_half;
|
||||
|
||||
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
const float input_left = static_cast<float>(left_vec[i]);
|
||||
const float input_right = static_cast<float>(right_vec[i]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void cache_kernel(
|
||||
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
|
||||
@@ -1300,6 +1232,411 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
uint32_t num_frags_y,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t HEAD_DIM,
|
||||
uint32_t BLOCK_SIZE,
|
||||
uint32_t NUM_WARPS,
|
||||
bool is_need_kv_quant,
|
||||
bool IsFP8 = true>
|
||||
__global__ void append_write_cache_kv_c8_qkv_dynamic(
|
||||
uint8_t *__restrict__ cache_k,
|
||||
uint8_t *__restrict__ cache_v,
|
||||
const T *__restrict__ qkv_input,
|
||||
T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size]
|
||||
T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size]
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids,
|
||||
const int *__restrict__ seq_lens_this_time,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
const int *__restrict__ batch_id_per_token,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_tables,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t pad_len = BLOCK_SIZE;
|
||||
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const T cache_k_scale = cache_k_scales[kv_head_idx];
|
||||
const T cache_v_scale = cache_v_scales[kv_head_idx];
|
||||
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
|
||||
const uint32_t batch_id = batch_ids[btid];
|
||||
const uint32_t tile_id = tile_ids[btid];
|
||||
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
|
||||
if (seq_len_this_time <= 0) {
|
||||
return;
|
||||
}
|
||||
const int *block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + batch_id * max_blocks_per_seq;
|
||||
|
||||
const uint32_t num_rows_per_block =
|
||||
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
|
||||
const uint32_t start_len = seq_lens_decoder[batch_id];
|
||||
const uint32_t bf_pad_len = start_len % pad_len;
|
||||
const uint32_t start_len_pad = start_len - bf_pad_len;
|
||||
const uint32_t end_len = start_len + seq_len_this_time;
|
||||
|
||||
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
||||
|
||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = HEAD_DIM;
|
||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_scale_smem[BLOCK_SIZE];
|
||||
if (tile_start >= start_len) {
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
// reset k
|
||||
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
|
||||
uint32_t tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
|
||||
tid % num_vecs_per_head_k * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_k;
|
||||
block_i < BLOCK_SIZE;
|
||||
block_i += num_token_each_time_k) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&cache_k[tgt_idx + block_i * HEAD_DIM]);
|
||||
}
|
||||
|
||||
// reset v
|
||||
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
|
||||
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
|
||||
tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
|
||||
tid % num_vecs_per_head_v * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
|
||||
block_i += num_token_each_time_v) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
|
||||
}
|
||||
}
|
||||
smem_t k_smem(k_smem_ori);
|
||||
smem_t v_smem(v_smem_ori);
|
||||
|
||||
uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp
|
||||
|
||||
/*
|
||||
0 | 1
|
||||
2 | 3
|
||||
*/
|
||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
||||
|
||||
constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
|
||||
/*
|
||||
0 | 2
|
||||
1 | 3
|
||||
*/
|
||||
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
tid % 16, wid * num_frags_v * 2 + tid / 16);
|
||||
|
||||
// load kv gmem to smem
|
||||
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
|
||||
tile_id * num_rows_per_block +
|
||||
wid * num_frags_z * 16 + tid / 8;
|
||||
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
|
||||
(num_heads + kv_head_idx) * kv_h_stride +
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
|
||||
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; ++j) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y / 4;
|
||||
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
|
||||
if (chunk_start >= start_len && chunk_start < end_len) {
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
|
||||
}
|
||||
kv_smem_offset_w =
|
||||
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
|
||||
k_read_idx += 8 * num_elems_per_128b<T>();
|
||||
v_read_idx += 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
kv_smem_offset_w =
|
||||
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
|
||||
2 * num_frags_y;
|
||||
chunk_start += 4;
|
||||
k_read_idx +=
|
||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
||||
v_read_idx +=
|
||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
commit_group();
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// reduce scale
|
||||
// 16 rows per warp
|
||||
uint32_t kv_reduce_frag[4];
|
||||
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);
|
||||
|
||||
T k_local_max_value[num_frags_z * 2];
|
||||
T v_local_max_value[num_frags_z * 2];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
||||
k_local_max_value[i] = -INFINITY;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
||||
v_local_max_value[i] = -INFINITY;
|
||||
}
|
||||
const int num_kv_heads = gridDim.z;
|
||||
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
|
||||
T *cache_k_scale_now = cache_k_scales + scale_offset;
|
||||
T *cache_v_scale_now = cache_v_scales + scale_offset;
|
||||
// k scale
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
// reduce per thread, 4 threads each row
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
|
||||
}
|
||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
// reduce per row
|
||||
for (int i = 0; i < 2; i++) {
|
||||
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
||||
// used for quant
|
||||
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
||||
}
|
||||
// store
|
||||
if (tid % 4 == 0) {
|
||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
||||
// used for dequant
|
||||
if (tile_start + offset_now >= start_len) {
|
||||
if (tile_start + offset_now < end_len) {
|
||||
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
|
||||
} else {
|
||||
cache_k_scale_now[offset_now] = 0;
|
||||
}
|
||||
}
|
||||
if (tile_start + offset_now + 8 >= start_len) {
|
||||
if (tile_start + offset_now + 8 < end_len) {
|
||||
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
|
||||
} else {
|
||||
cache_k_scale_now[offset_now + 8] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
||||
}
|
||||
// v scale
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
// reduce per thread, 4 threads each row
|
||||
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
|
||||
}
|
||||
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
// reduce per row
|
||||
for (int i = 0; i < 2; i++) {
|
||||
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
||||
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
||||
}
|
||||
// store
|
||||
if (tid % 4 == 0) {
|
||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
||||
// used for dequant
|
||||
if (tile_start + offset_now >= start_len) {
|
||||
if (tile_start + offset_now < end_len) {
|
||||
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
|
||||
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
|
||||
} else {
|
||||
cache_v_scale_now[offset_now] = 0;
|
||||
v_scale_smem[offset_now] = 0;
|
||||
}
|
||||
}
|
||||
if (tile_start + offset_now + 8 >= start_len) {
|
||||
if (tile_start + offset_now + 8 < end_len) {
|
||||
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
|
||||
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
|
||||
} else {
|
||||
cache_v_scale_now[offset_now + 8] = 0;
|
||||
v_scale_smem[offset_now + 8] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// mask, quant, store
|
||||
using LoadKVT = AlignedVector<uint8_t, 4>;
|
||||
LoadKVT cache_vec1;
|
||||
LoadKVT cache_vec2;
|
||||
|
||||
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
|
||||
uint32_t kv_frag[4];
|
||||
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_b_stride = HEAD_DIM;
|
||||
const uint32_t write_d_stride = BLOCK_SIZE;
|
||||
uint32_t k_write_idx = block_id * write_n_stride +
|
||||
kv_head_idx * write_h_stride +
|
||||
(wid * num_frags_z * 16 + tid / 4) * write_b_stride +
|
||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t k_write_idx_now = k_write_idx_now_z +
|
||||
fy % 2 * 8 * write_b_stride +
|
||||
fy / 2 * 32; // + fy % 2 * 16;
|
||||
// load
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||
// quant
|
||||
T *k_frag_T = reinterpret_cast<T *>(kv_frag);
|
||||
if (bf_pad_len != 0) {
|
||||
Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
|
||||
Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
||||
uint8_t uint_quant_value;
|
||||
if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
|
||||
chunk_start_k + (v_id / 4) * 8 < end_len) {
|
||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
|
||||
} else {
|
||||
uint_quant_value = 0;
|
||||
}
|
||||
if (bf_pad_len != 0) {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] |= uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
||||
}
|
||||
} else {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] = uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id - 4] = uint_quant_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
|
||||
Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
|
||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
k_smem_offset_r =
|
||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
|
||||
2 * num_frags_y;
|
||||
chunk_start_k += 16;
|
||||
}
|
||||
|
||||
uint32_t chunk_start_v = tile_start + tid % 4 * 2;
|
||||
uint32_t v_write_idx = block_id * write_n_stride +
|
||||
kv_head_idx * write_h_stride +
|
||||
(wid * num_frags_v * 16 + tid / 4) * write_d_stride +
|
||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
||||
const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
|
||||
T v_scales[num_frags_z_v * 4];
|
||||
for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
|
||||
const int offset = v_i * 16;
|
||||
const int t_offset = tid % 4 * 2;
|
||||
v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
|
||||
v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
|
||||
v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
|
||||
v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
|
||||
uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
|
||||
uint32_t v_write_idx_now = v_write_idx_now_v +
|
||||
fz % 2 * 8 * write_d_stride +
|
||||
fz / 2 * 32; // + fz % 2 * 16;
|
||||
// load
|
||||
v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
|
||||
// quant
|
||||
T *v_frag_T = reinterpret_cast<T *>(kv_frag);
|
||||
if (bf_pad_len != 0) {
|
||||
Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
|
||||
Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
||||
uint8_t uint_quant_value;
|
||||
if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
|
||||
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
|
||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
|
||||
// store now
|
||||
} else {
|
||||
uint_quant_value = 0;
|
||||
}
|
||||
if (bf_pad_len != 0) {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] |= uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
||||
}
|
||||
} else {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] = uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] = uint_quant_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
|
||||
Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
|
||||
chunk_start_v += 16;
|
||||
v_smem_offset_r =
|
||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
|
||||
}
|
||||
v_smem_offset_r = k_smem.advance_offset_by_column<2>(
|
||||
v_smem_offset_r, wid * num_frags_v + fy) -
|
||||
16 * num_frags_z_v * num_vecs_per_head;
|
||||
chunk_start_v -= 16 * num_frags_z_v;
|
||||
}
|
||||
}
|
||||
|
||||
// Write Cache KV in Append
|
||||
template <typename T,
|
||||
uint32_t num_frags_y,
|
||||
@@ -1823,7 +2160,6 @@ void gqa_rotary_qk_variable(
|
||||
const int seq_len,
|
||||
const int input_output_len,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const cudaStream_t &stream,
|
||||
bool use_neox_style = false,
|
||||
bool rope_3d = false) {
|
||||
@@ -1904,38 +2240,7 @@ void gqa_rotary_qk_variable(
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (rotary_dim < dim_head){
|
||||
PD_CHECK((rotary_dim / 2) % PackSize == 0);
|
||||
elem_nums =
|
||||
qkv_out_scales
|
||||
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
|
||||
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
|
||||
if (use_neox_style) {
|
||||
elem_nums /= 2;
|
||||
}
|
||||
const int pack_num_new = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num_new, &grid_size);
|
||||
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
rotary_emb + input_output_len * rotary_dim / 2,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
rope_3d);
|
||||
}else{
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
@@ -1953,7 +2258,6 @@ void gqa_rotary_qk_variable(
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2107,10 +2411,11 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
int num_blocks_x_cpu,
|
||||
int max_seq_len,
|
||||
bool is_scale_channel_wise,
|
||||
const bool is_fp8,
|
||||
const std::string& cache_quant_type,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *cache_k_out,
|
||||
paddle::Tensor *cache_v_out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto num_tokens = meta_data.token_nums;
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
@@ -2128,49 +2433,77 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
dim3 blocks(32, num_warps);
|
||||
|
||||
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, false>;
|
||||
if (is_fp8) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
if (cache_quant_type != "block_wise_fp8") {
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, false>;
|
||||
if (cache_quant_type == "cache_fp8") {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
false>;
|
||||
}
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
qkv.data<T>(),
|
||||
cache_k_scale.data<T>(),
|
||||
cache_v_scale.data<T>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
} else {
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
|
||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
|
||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
false>;
|
||||
}
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
qkv.data<T>(),
|
||||
cache_k_scale.data<T>(),
|
||||
cache_v_scale.data<T>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
}
|
||||
|
||||
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
|
||||
|
@@ -55,19 +55,9 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto head_dim = meta_data.head_dims;
|
||||
bool is_scale_channel_wise = false;
|
||||
int rotary_dim = head_dim;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (rotary_embs){
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < head_dim){
|
||||
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
|
||||
@@ -135,7 +125,6 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
@@ -178,7 +167,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
|
||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
||||
DISPATCH_HEAD_DIM(
|
||||
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
|
||||
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
|
||||
@@ -198,7 +187,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
is_scale_channel_wise,
|
||||
cache_quant_type_str == "cache_fp8",
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
|
@@ -191,14 +191,21 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
void GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
|
||||
paddle::Tensor &kv_batch_ids, // Inplace
|
||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
|
||||
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
@@ -223,13 +230,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
int max_system_len = max_len_cpu_ptr[6];
|
||||
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
|
||||
|
||||
paddle::Tensor encoder_batch_ids;
|
||||
paddle::Tensor encoder_tile_ids_per_batch;
|
||||
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor kv_batch_ids;
|
||||
paddle::Tensor kv_tile_ids_per_batch;
|
||||
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||
|
||||
|
||||
auto max_len_kv =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
|
||||
@@ -237,17 +238,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(), bsz);
|
||||
|
||||
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
|
||||
|
||||
if (max_enc_len_this_time > 0) {
|
||||
const uint32_t max_tile_size_per_bs_kv =
|
||||
div_up(max_enc_dec_len_this_time, block_size);
|
||||
kv_batch_ids =
|
||||
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
|
||||
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
|
||||
auto kv_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
|
||||
@@ -258,16 +256,12 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
|
||||
block_size, block_size);
|
||||
|
||||
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
const uint32_t encoder_max_tile_size_per_bs_q =
|
||||
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||
encoder_batch_ids =
|
||||
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
|
||||
// Clear buffer
|
||||
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
|
||||
auto encoder_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
|
||||
@@ -275,21 +269,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
encoder_tile_ids_per_batch.data<int>(),
|
||||
encoder_num_blocks_x.data<int>(), bsz,
|
||||
encoder_block_shape_q, group_size);
|
||||
encoder_num_blocks_x_cpu =
|
||||
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
} else {
|
||||
encoder_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
kv_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
|
||||
}
|
||||
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
@@ -314,15 +294,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
|
||||
}
|
||||
|
||||
return {
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks_x_cpu, /*cpu*/
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks_x_cpu, /*cpu*/
|
||||
max_len_kv_cpu, /*cpu*/
|
||||
};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
@@ -333,16 +304,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks_x_cpu",
|
||||
"max_len_tensor_cpu"
|
||||
"max_len_tensor_cpu",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks_x_cpu",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks_x_cpu",
|
||||
"max_len_kv_cpu"
|
||||
})
|
||||
.Outputs({
|
||||
paddle::Optional("encoder_batch_ids"),
|
||||
paddle::Optional("encoder_tile_ids_per_batch"),
|
||||
paddle::Optional("encoder_num_blocks_x_cpu"),
|
||||
paddle::Optional("kv_batch_ids"),
|
||||
paddle::Optional("kv_tile_ids_per_batch"),
|
||||
paddle::Optional("kv_num_blocks_x_cpu"),
|
||||
"max_len_kv_cpu"
|
||||
|
||||
})
|
||||
.Attrs({
|
||||
"encoder_block_shape_q: int",
|
||||
|
@@ -217,7 +217,7 @@ __global__ void append_cache_kv_c16(
|
||||
|
||||
// load k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -235,7 +235,7 @@ __global__ void append_cache_kv_c16(
|
||||
// deal k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||
// layout
|
||||
@@ -278,7 +278,7 @@ __global__ void append_cache_kv_c16(
|
||||
|
||||
// load v_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -296,7 +296,7 @@ __global__ void append_cache_kv_c16(
|
||||
// deal v_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
|
||||
// layout
|
||||
@@ -400,7 +400,7 @@ __global__ void append_cache_kv_c8(
|
||||
|
||||
// load v_smem 64 rows, 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -418,7 +418,7 @@ __global__ void append_cache_kv_c8(
|
||||
// deal k_smem 64 rows, 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter
|
||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter
|
||||
uint32_t col_idx = fy * 32 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||
// layout
|
||||
@@ -466,7 +466,7 @@ __global__ void append_cache_kv_c8(
|
||||
tid % 4 * num_elems_per_128b<CacheT>();
|
||||
// load v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 noce, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -485,7 +485,7 @@ __global__ void append_cache_kv_c8(
|
||||
// deal v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 noce, need 2 iter
|
||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter
|
||||
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||
// layout
|
||||
@@ -614,7 +614,7 @@ __global__ void append_cache_kv_c4(
|
||||
|
||||
// load k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 noce, need 1 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -632,7 +632,7 @@ __global__ void append_cache_kv_c4(
|
||||
// deal k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 noce, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
|
||||
uint32_t col_idx = fy * 64 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||
|
||||
@@ -685,7 +685,7 @@ __global__ void append_cache_kv_c4(
|
||||
tid % 2 * num_elems_per_128b<CacheT>();
|
||||
// load v_smem 128 rows 64 rows
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -704,7 +704,7 @@ __global__ void append_cache_kv_c4(
|
||||
// deal v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
||||
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||
// layout
|
||||
@@ -1000,7 +1000,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
|
||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
|
||||
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
|
||||
meta_data,
|
||||
*const_cast<paddle::Tensor*>(&key_cache),
|
||||
@@ -1018,7 +1018,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
kv_num_blocks_data,
|
||||
max_seq_len,
|
||||
false, // is_scale_channel_wise
|
||||
cache_quant_type == "cache_fp8", // is_fp8
|
||||
cache_quant_type,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
|
@@ -18,6 +18,166 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ q_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
const float*
|
||||
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
|
||||
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int output_inner_dim,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
LoadInT src_vec;
|
||||
LoadFloat scale_vec;
|
||||
LoadT bias_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
LoadFloat tmp_vec;
|
||||
LoadFloat q_norm_vec;
|
||||
LoadFloat k_norm_vec;
|
||||
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
int64_t all_head_dim = elem_cnt / head_size;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
|
||||
const int half_head_size = head_size / 2;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens_decoder[ori_bi],
|
||||
token_id,
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
const int write_q_idx =
|
||||
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
|
||||
|
||||
const int bias_idx = hi * head_size + h_bias;
|
||||
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
|
||||
if (qkv_biases) {
|
||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||
}
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
}
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
if (qkv_out_scales) {
|
||||
input_left *= scale_vec[2 * i];
|
||||
input_right *= scale_vec[2 * i + 1];
|
||||
}
|
||||
if (qkv_biases) {
|
||||
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
|
||||
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < (num_heads + gqa_group_size)) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
if (hi < num_heads) {
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else {
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
|
||||
} else {
|
||||
// write k/v
|
||||
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
|
||||
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size +
|
||||
block_offset * head_size + h_bias);
|
||||
// write
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
|
||||
} else {
|
||||
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int VecSize = 4, int HeadDim = 128>
|
||||
__global__ void append_clear_cache_int8_block(
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
||||
@@ -193,7 +353,8 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -253,8 +414,9 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -326,7 +488,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -390,8 +553,9 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -476,7 +640,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -522,8 +687,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
}
|
||||
@@ -583,10 +749,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
T scale;
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
} else {
|
||||
scale = __ldg(&cache_v_scales[kv_head_idx]);
|
||||
@@ -708,7 +875,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -757,8 +925,9 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
|
||||
&left_out_scale_vec);
|
||||
@@ -853,10 +1022,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
@@ -1088,7 +1258,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1145,8 +1316,9 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -1235,10 +1407,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// &out_scale_vec2);
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||
@@ -1431,7 +1604,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1581,10 +1755,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec2);
|
||||
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
|
||||
&left_scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
|
||||
|
@@ -15,6 +15,77 @@
|
||||
#include "speculate_write_cache_with_rope_kernel.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
constexpr int HEAD_DIM = 128;
|
||||
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
} else {
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
output_inner_dim,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
}
|
||||
|
||||
// rope + write
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
@@ -39,7 +110,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
|
||||
const uint32_t elem_nums =
|
||||
@@ -73,7 +145,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_rope_kernel<T, PackSize>
|
||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||
@@ -96,7 +169,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +199,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -167,7 +242,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -191,7 +267,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,7 +299,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -266,7 +344,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_int4_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -292,7 +371,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
template <typename T, typename QKV_TYPE>
|
||||
@@ -313,11 +393,15 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out) {
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
typedef cascade_attn_type_traits<T> traits_;
|
||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||
typedef typename traits_::type DataType_;
|
||||
@@ -342,142 +426,184 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
}
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope_qk_norm(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
|
||||
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -500,11 +626,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
template void
|
||||
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
@@ -526,11 +656,15 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -551,11 +685,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
|
||||
template void
|
||||
@@ -578,8 +716,12 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
@@ -35,8 +35,12 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
@@ -56,6 +56,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -103,5 +104,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -98,5 +99,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -441,6 +441,15 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
PD_THROW("not support the group_size", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
|
||||
if (is_dynamic_cfp8) { \
|
||||
constexpr bool IsDynamicC8 = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
constexpr bool IsDynamicC8 = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
|
@@ -255,7 +255,8 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums);
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size);
|
||||
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
@@ -298,7 +299,7 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
|
||||
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
|
||||
const int layer_id);
|
||||
|
||||
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
void GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
@@ -306,6 +307,13 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &kv_batch_ids, // Inplace
|
||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
@@ -378,9 +386,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size);
|
||||
|
||||
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
paddle::Tensor
|
||||
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
|
||||
@@ -564,7 +574,6 @@ std::vector<paddle::Tensor> NoauxTc(
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor);
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
@@ -616,8 +625,6 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
|
||||
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
void clear_ipc_handles(int64_t _fa);
|
||||
|
||||
// speculative decoding Kernel
|
||||
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
const paddle::Tensor& input_ids,
|
||||
@@ -710,6 +717,22 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
@@ -753,6 +776,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
@@ -766,7 +790,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill);
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
|
||||
|
||||
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
@@ -983,7 +1008,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"),
|
||||
py::arg("block_size"),
|
||||
"per token per block quant and padding tranpose scale");
|
||||
"per token per block quant and padding transpose scale");
|
||||
|
||||
m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"),
|
||||
py::arg("recv_expert_count"), py::arg("block_size"),
|
||||
@@ -1026,7 +1051,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_ffn_wint2.cu
|
||||
* moe/fused_moe/moe_expert_ffn_wint2.cu
|
||||
* moe_expert_ffn_wint2
|
||||
*/
|
||||
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
|
||||
@@ -1206,8 +1231,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
|
||||
|
||||
m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");
|
||||
|
||||
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
|
||||
@@ -1233,6 +1256,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
|
||||
|
||||
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
|
||||
|
||||
m.def("ngram_match", &NgramMatch, "ngram_match function");
|
||||
|
||||
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
|
||||
|
@@ -122,14 +122,10 @@ void register_graph_buffers(fptr_t _fa,
|
||||
for (int i = 0; i < handles.size(); i++) {
|
||||
bytes.emplace_back(handles[i].begin(), handles[i].end());
|
||||
}
|
||||
bytes.reserve(handles.size());
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
void clear_ipc_handles(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
fa->clear_ipc_handles();
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
|
||||
|
@@ -517,15 +517,10 @@ class CustomAllreduce {
|
||||
#undef KL
|
||||
}
|
||||
|
||||
void clear_ipc_handles(){
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
ipc_handles_.clear();
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
clear_ipc_handles();
|
||||
}
|
||||
};
|
||||
} // namespace paddle
|
||||
|
@@ -89,11 +89,11 @@ public:
|
||||
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of warp-level GEMM oeprations per load for B
|
||||
/// Number of warp-level GEMM operations per load for B
|
||||
static constexpr int kWarpGemmIterationsPerLoadForB =
|
||||
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
|
||||
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
|
||||
|
@@ -117,7 +117,7 @@ class LeftGELUAndMul {
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &lhs,
|
||||
FragmentAccumulator const &rhs) const {
|
||||
// Convert source to interal compute numeric type
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
||||
accumulator_to_compute;
|
||||
|
||||
|
@@ -117,7 +117,7 @@ class LeftSiLUAndMul {
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &lhs,
|
||||
FragmentAccumulator const &rhs) const {
|
||||
// Convert source to interal compute numeric type
|
||||
// Convert source to internal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
||||
accumulator_to_compute;
|
||||
|
||||
|
@@ -92,7 +92,7 @@ class DualMmaBase {
|
||||
Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator0::Policy::MmaShape::kK);
|
||||
|
||||
|
@@ -219,7 +219,7 @@ class EpilogueVisitorPerRowPerColNf4 {
|
||||
iterator_C_.clear_mask();
|
||||
}
|
||||
// NOTE(wangbojun) Currently, this kernel don't hanve implantention for
|
||||
// adding elementwise beta, we keep this here for future useage beta_ =
|
||||
// adding elementwise beta, we keep this here for future usage beta_ =
|
||||
// (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr :
|
||||
// params.elementwise.beta); if (beta_ == ElementAccumulator()) {
|
||||
// iterator_C_.clear_mask();
|
||||
|
@@ -176,7 +176,7 @@ struct Nf4DefaultIteratorsTensorOp<cutlass::bfloat16_t,
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
|
||||
public:
|
||||
|
@@ -64,7 +64,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
|
@@ -133,7 +133,7 @@ public:
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
|
||||
static constexpr int kNumKIterationsPerWarpBLoad =
|
||||
|
@@ -509,7 +509,7 @@ public:
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
// TOOD(wangbojun) lds_converter can be remove for int8 B input
|
||||
// TODO(wangbojun) lds_converter can be remove for int8 B input
|
||||
typename TransformBAfterLDS::result_type converted_frag_B =
|
||||
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
|
||||
|
@@ -96,7 +96,7 @@ public:
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
|
||||
static constexpr int kNumKIterationsPerWarpBLoad =
|
||||
|
@@ -646,7 +646,7 @@ public:
|
||||
// );
|
||||
// }
|
||||
}
|
||||
// TOOD(wangbojun) lds_converter can be remove for int8 B input
|
||||
// TODO(wangbojun) lds_converter can be remove for int8 B input
|
||||
// int4
|
||||
// typename TransformBAfterLDS::result_type converted_frag_B =
|
||||
// lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
|
@@ -151,34 +151,6 @@ inline int GetGPUComputeCapability(int id) {
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef FP8_E4M3_MAX
|
||||
#define FP8_E4M3_MAX 448.0
|
||||
#endif
|
||||
|
||||
#ifndef DISPATCH_FLOAT_FP6_DTYPE
|
||||
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
|
||||
switch (pd_dtype) { \
|
||||
case phi::DataType::FLOAT32: { \
|
||||
using c_type = float; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::BFLOAT16: { \
|
||||
using c_type = phi::dtype::bfloat16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::FLOAT16: { \
|
||||
using c_type = phi::dtype::float16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1)
|
||||
return num;
|
||||
@@ -591,28 +563,3 @@ inline int GetSMVersion() {
|
||||
return sm_version;
|
||||
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warpReduceMax(float value) {
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float blockReduceMax(float value) {
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
value = warpReduceMax(value);
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = value;
|
||||
__syncthreads();
|
||||
|
||||
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
if (warpId == 0) value = warpReduceMax(value);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
@@ -171,7 +171,7 @@ struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
|
||||
public:
|
||||
|
@@ -30,12 +30,10 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
|
||||
std::optional<paddle::Tensor> const& maybe_token_scales,
|
||||
std::string maybe_schedule) {
|
||||
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
|
||||
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
|
||||
std::optional<int64_t> maybe_group_size_opt;
|
||||
std::optional<std::string> maybe_schedule_opt;
|
||||
if (maybe_schedule == "") {
|
||||
maybe_schedule_opt = std::nullopt;
|
||||
} else {
|
||||
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
|
||||
}
|
||||
return machete::mm_dispatch({.A = A,
|
||||
.B = B,
|
||||
@@ -65,8 +63,6 @@ std::vector<paddle::Tensor> MacheteMMKernel(
|
||||
paddle::DataType maybe_out_type;
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
@@ -51,8 +51,6 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
|
||||
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
@@ -383,7 +383,7 @@ __global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attenti
|
||||
|
||||
|
||||
template<typename Kernel_traits, typename ParamType>
|
||||
inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType ¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
|
||||
inline __device__ float calculate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType ¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
|
||||
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
|
||||
constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads;
|
||||
const int32_t bi = blockIdx.z;
|
||||
@@ -524,7 +524,7 @@ __global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_a
|
||||
const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize;
|
||||
const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN;
|
||||
|
||||
float inv_global_exp_sum = caluate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
|
||||
float inv_global_exp_sum = calculate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
|
||||
|
||||
|
||||
using T_vec = Vec<cuteType, kNReducePacksize>;
|
||||
|
@@ -40,7 +40,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
if (seq_len == 0) return;
|
||||
|
||||
const int ramian_tokens = seq_len - block_idx;
|
||||
const int remain_tokens = seq_len - block_idx;
|
||||
|
||||
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
|
||||
const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize];
|
||||
@@ -51,7 +51,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(k_input + base_load_idx + i * kv_head_num * kHeadDim);
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(v_input + base_load_idx + i * kv_head_num * kHeadDim);
|
||||
}
|
||||
}
|
||||
|
@@ -50,14 +50,14 @@ __global__ void get_kv_from_cache_c16_kernel(
|
||||
const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx];
|
||||
|
||||
|
||||
const int ramian_tokens = seq_len - base_token_idx;
|
||||
const int remain_tokens = seq_len - base_token_idx;
|
||||
|
||||
if (bidh < kv_head_num) {
|
||||
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_k + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
@@ -67,7 +67,7 @@ __global__ void get_kv_from_cache_c16_kernel(
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_v + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
|
@@ -872,16 +872,14 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
OutT* out,
|
||||
cudaStream_t &stream) {
|
||||
bool FLAGS_hardamard_use_diagonal_block_matrix = true;
|
||||
|
||||
static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
|
||||
static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
|
||||
stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
|
||||
constexpr int kThreads = 128;
|
||||
if (FLAGS_hardamard_use_diagonal_block_matrix) {
|
||||
const int VecSize = hardamard_moe_block_size / kThreads; // 128 / 128 = 1
|
||||
const int VecSize = hadamard_block_size / kThreads;
|
||||
const int logN = int(ceil(std::log2(kThreads * VecSize)));
|
||||
constexpr int kNChunks = 1;
|
||||
DISPATCH_SP_VS(VecSize, VEC_SIZE, {
|
||||
@@ -991,6 +989,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::float16 *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1009,6 +1008,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1027,6 +1027,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::bfloat16 *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1045,6 +1046,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
|
@@ -32,5 +32,6 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
OutT* out,
|
||||
cudaStream_t &stream);
|
||||
|
@@ -80,7 +80,7 @@ void MoeDispatchKernel(
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill sucess ?)
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
@@ -35,7 +35,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const std::string& quant_method,
|
||||
paddle::Tensor ffn_out,
|
||||
bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
@@ -291,6 +292,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
|
||||
stream
|
||||
);
|
||||
@@ -340,6 +342,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
act_out_tensor.data<data_t>(),
|
||||
stream
|
||||
);
|
||||
@@ -403,7 +406,7 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums, const int hadamard_block_size) {
|
||||
|
||||
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
|
||||
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
||||
@@ -424,7 +427,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
quant_method,
|
||||
ffn_out,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums);
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
@@ -439,7 +443,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
quant_method,
|
||||
ffn_out,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums);
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeExpertFFN");
|
||||
@@ -458,7 +463,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
return {MoeExpertFFNFunc(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
up_gate_proj_weight,
|
||||
@@ -470,7 +476,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
expert_idx_per_token,
|
||||
quant_method,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums)};
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size)};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
@@ -485,7 +492,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
|
||||
const std::string& quant_method,
|
||||
const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
return {permute_input_shape};
|
||||
}
|
||||
|
||||
@@ -499,7 +507,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
|
||||
const std::string &quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums) {
|
||||
const int estimate_total_token_nums, const int hadamard_block_size) {
|
||||
if (quant_method == "w4a8" || quant_method == "w4afp8") {
|
||||
return {up_gate_proj_scale_dtype.get()};
|
||||
} else {
|
||||
@@ -555,6 +563,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
* Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
|
||||
* - used_in_ep_low_latency: Whether running in low latency mode
|
||||
* Affects activation function implementation
|
||||
* - estimate_total_token_nums: estimate total token numbers
|
||||
* - hadamard_block_size: hadamard block size for w4a8/w4afp8 quantization
|
||||
*
|
||||
* Note:
|
||||
* - w4a8 mode requires additional workspace memory allocation
|
||||
@@ -571,7 +581,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
|
||||
paddle::Optional("down_proj_in_scale"),
|
||||
paddle::Optional("expert_idx_per_token")})
|
||||
.Outputs({"output_tensor"})
|
||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
|
||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
|
||||
|
@@ -26,7 +26,6 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor) {
|
||||
auto input_shape = scores_with_bias.shape();
|
||||
PD_CHECK(input_shape.size() == 2);
|
||||
@@ -49,7 +48,6 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
|
||||
@@ -78,7 +76,6 @@ PD_BUILD_STATIC_OP(noaux_tc)
|
||||
.Attrs({"n_group: int",
|
||||
"topk_group: int",
|
||||
"topk:int",
|
||||
"renormalize: bool",
|
||||
"routed_scaling_factor: float"})
|
||||
.SetKernelFn(PD_KERNEL(NoauxTc))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
|
||||
|
@@ -25,23 +25,6 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t BLOCK_SIZE = 512;
|
||||
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__device__ inline T_OUT cuda_cast(T_IN val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T neg_inf() {
|
||||
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
|
||||
// so we need to cast from fp32
|
||||
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
namespace warp_topk {
|
||||
|
||||
template <int size, typename T>
|
||||
@@ -58,21 +41,10 @@ constexpr __host__ __device__ bool isPowerOf2(T v) {
|
||||
}
|
||||
|
||||
template <bool greater, typename T>
|
||||
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
|
||||
__device__ bool is_better_than(T val, T baseline) {
|
||||
return (val > baseline && greater) || (val < baseline && !greater);
|
||||
}
|
||||
|
||||
template <bool greater, typename T, typename idxT>
|
||||
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
|
||||
idxT baseline_index) {
|
||||
bool res = (val > baseline && greater) || (val < baseline && !greater);
|
||||
if (val == baseline) {
|
||||
res = (index < baseline_index && greater) ||
|
||||
(index < baseline_index && !greater);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T, typename idxT>
|
||||
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
|
||||
@@ -81,8 +53,7 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
|
||||
}
|
||||
|
||||
template <int size, bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
struct BitonicMerge {
|
||||
// input should be a bitonic sequence, and sort it to be a monotonic sequence
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
@@ -96,15 +67,7 @@ struct BitonicMerge {
|
||||
int const other_i = i + stride;
|
||||
T& val = val_arr[i];
|
||||
T& other_val = val_arr[other_i];
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
|
||||
idx_arr[other_i]);
|
||||
} else {
|
||||
is_better = is_better_than<ascending>(val, other_val);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
if ((val > other_val && ascending) || (val < other_val && !ascending)) {
|
||||
T tmp = val;
|
||||
val = other_val;
|
||||
other_val = tmp;
|
||||
@@ -115,14 +78,13 @@ struct BitonicMerge {
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
struct BitonicSort {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
@@ -130,16 +92,15 @@ struct BitonicSort {
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort<32, ascending, T, idxT, is_stable> {
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicSort<32, ascending, T, idxT> {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
@@ -153,37 +114,19 @@ struct BitonicSort<32, ascending, T, idxT, is_stable> {
|
||||
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) !=
|
||||
(reverse != is_second);
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) !=
|
||||
(reverse != is_second);
|
||||
}
|
||||
} else {
|
||||
is_better = (*val_arr != other &&
|
||||
(*val_arr > other) != (reverse != is_second));
|
||||
}
|
||||
if (is_better) {
|
||||
if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
|
||||
*val_arr = other;
|
||||
*idx_arr = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
|
||||
idx_arr);
|
||||
BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicMerge<32, ascending, T, idxT> {
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
@@ -193,24 +136,7 @@ struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
|
||||
idxT& idx = *idx_arr;
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) ==
|
||||
(reverse != is_second); // for min
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) ==
|
||||
(reverse != is_second); // for max
|
||||
}
|
||||
} else {
|
||||
is_better =
|
||||
(val != other && ((val > other) == (ascending != is_second)));
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
if (val != other && ((val > other) == (ascending != is_second))) {
|
||||
val = other;
|
||||
idx = other_idx;
|
||||
}
|
||||
@@ -218,42 +144,34 @@ struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
|
||||
}
|
||||
};
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
template <int capacity, bool greater, typename T, typename idxT>
|
||||
class WarpSort {
|
||||
public:
|
||||
public:
|
||||
__device__ WarpSort(idxT k, T dummy)
|
||||
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
|
||||
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
|
||||
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
val_arr_[i] = dummy_;
|
||||
idx_arr_[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// load and merge k sorted values
|
||||
__device__ void load_sorted(T const* __restrict__ in,
|
||||
idxT const* __restrict__ in_idx, idxT start) {
|
||||
idxT const* __restrict__ in_idx,
|
||||
idxT start) {
|
||||
idxT idx = start + WARP_SIZE - 1 - lane_;
|
||||
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
|
||||
if (idx < start + k_) {
|
||||
T t = in[idx];
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(t, val_arr_[i]);
|
||||
}
|
||||
if (is_better) {
|
||||
if (is_better_than<greater>(t, val_arr_[i])) {
|
||||
val_arr_[i] = t;
|
||||
idx_arr_[i] = in_idx[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
|
||||
}
|
||||
|
||||
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
|
||||
@@ -275,7 +193,7 @@ class WarpSort {
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
protected:
|
||||
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
|
||||
|
||||
T val_arr_[max_arr_len_];
|
||||
@@ -287,11 +205,11 @@ class WarpSort {
|
||||
|
||||
}; // end class WarpSort
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
|
||||
public:
|
||||
template <int capacity, bool greater, typename T, typename idxT>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
|
||||
public:
|
||||
__device__ WarpSelect(idxT k, T dummy)
|
||||
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
|
||||
: WarpSort<capacity, greater, T, idxT>(k, dummy),
|
||||
k_th_(dummy),
|
||||
k_th_lane_((k - 1) % WARP_SIZE) {
|
||||
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
|
||||
@@ -316,13 +234,7 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
|
||||
}
|
||||
|
||||
__device__ void add(T val, idxT idx) {
|
||||
bool do_add;
|
||||
if constexpr (is_stable) {
|
||||
do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
|
||||
} else {
|
||||
do_add = is_better_than<greater>(val, k_th_);
|
||||
}
|
||||
|
||||
bool do_add = is_better_than<greater>(val, k_th_);
|
||||
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
|
||||
if (mask == 0) {
|
||||
return;
|
||||
@@ -359,52 +271,37 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
__device__ void set_k_th_() {
|
||||
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
if constexpr (is_stable) {
|
||||
k_th_idx_ =
|
||||
__shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void merge_buf_(T val, idxT idx) {
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
|
||||
|
||||
T& old = val_arr_[max_arr_len_ - 1];
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(val, old);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
if (is_better_than<greater>(val, old)) {
|
||||
old = val;
|
||||
idx_arr_[max_arr_len_ - 1] = idx;
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
|
||||
|
||||
set_k_th_();
|
||||
}
|
||||
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
|
||||
using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT>::dummy_;
|
||||
|
||||
T* val_smem_;
|
||||
idxT* idx_smem_;
|
||||
int smem_buf_len_ = 0;
|
||||
|
||||
T k_th_;
|
||||
idxT k_th_idx_;
|
||||
int const k_th_lane_;
|
||||
}; // end class WarpSelect
|
||||
} // namespace warp_topk
|
||||
@@ -416,8 +313,8 @@ __device__ void topk_with_k2(T* output,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
// Get the top2 per thread
|
||||
T largest = neg_inf<T>();
|
||||
T second_largest = neg_inf<T>();
|
||||
T largest = cuda::std::numeric_limits<T>::min();
|
||||
T second_largest = cuda::std::numeric_limits<T>::min();
|
||||
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
@@ -471,14 +368,8 @@ __global__ void topk_with_k2_kernel(T* output,
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
@@ -494,7 +385,6 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int64_t const topk,
|
||||
int64_t const num_experts,
|
||||
int64_t const num_experts_per_group,
|
||||
bool const renormalize,
|
||||
double routed_scaling_factor) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
@@ -513,29 +403,19 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
|
||||
// store the target topk idx
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
|
||||
T* s_topk_value =
|
||||
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
|
||||
|
||||
T value = neg_inf<T>();
|
||||
T topk_group_value = neg_inf<T>();
|
||||
T value = cuda::std::numeric_limits<T>::min();
|
||||
T topk_group_value = cuda::std::numeric_limits<T>::min();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
|
||||
// acqbulk because it's ptr arithmetic
|
||||
#endif
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
if ((n_group > topk_group) && (case_id < num_tokens)) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group &&
|
||||
(isfinite(cuda_cast<float, T>(
|
||||
group_scores[lane_id])))) // The check is necessary to avoid
|
||||
// abnormal input
|
||||
{
|
||||
if (lane_id < n_group) {
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
@@ -546,23 +426,22 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = neg_inf<T>();
|
||||
value = cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == neg_inf<T>())));
|
||||
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
||||
/* is_stable */ true>
|
||||
queue((int32_t)topk, neg_inf<T>());
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
|
||||
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
|
||||
if (case_id < num_tokens) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
((group_scores[i_group] == topk_group_value) &&
|
||||
@@ -570,11 +449,9 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates =
|
||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||
scores_with_bias[offset + i]))
|
||||
? scores_with_bias[offset + i]
|
||||
: neg_inf<T>();
|
||||
T candidates = i < num_experts_per_group
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda::std::numeric_limits<T>::min();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
@@ -592,7 +469,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
// Load the valid score value
|
||||
// Calculate the summation
|
||||
float topk_sum = 1e-20;
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
if (case_id < num_tokens) {
|
||||
for (int i = lane_id;
|
||||
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
||||
i += WARP_SIZE) {
|
||||
@@ -601,45 +478,33 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
topk_sum += reduce(tile, value, cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
if (case_id < num_tokens) {
|
||||
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
|
||||
scores[i] = 0;
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
__threadfence();
|
||||
__syncthreads();
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
if (if_proceed_next_topk) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value;
|
||||
if (renormalize) {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
|
||||
routed_scaling_factor;
|
||||
} else {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
|
||||
}
|
||||
scores[s_topk_idx[i]] = value;
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
|
||||
scores[s_topk_idx[i]] = value;
|
||||
if (if_proceed_next_topk) {
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = cuda_cast<T, float>(value);
|
||||
topk_values[i] = static_cast<T>(value);
|
||||
}
|
||||
} else {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
else {
|
||||
topk_indices[i] = i;
|
||||
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
|
||||
topk_values[i] = static_cast<float>(1.0f / topk);
|
||||
}
|
||||
}
|
||||
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
|
||||
// default result.
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
@@ -653,24 +518,17 @@ void invokeNoAuxTc(T* scores,
|
||||
int64_t const n_group,
|
||||
int64_t const topk_group,
|
||||
int64_t const topk,
|
||||
bool const renormalize,
|
||||
double const routed_scaling_factor,
|
||||
cudaStream_t const stream) {
|
||||
int64_t num_cases = num_tokens * n_group;
|
||||
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = topk_with_k2_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = false;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
|
||||
num_tokens, num_cases, n_group, num_experts / n_group);
|
||||
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
|
||||
group_scores,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
num_cases,
|
||||
n_group,
|
||||
num_experts / n_group);
|
||||
|
||||
int64_t topk_with_k_group_num_blocks =
|
||||
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
@@ -678,19 +536,21 @@ void invokeNoAuxTc(T* scores,
|
||||
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
||||
topk);
|
||||
|
||||
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
|
||||
config.gridDim = topk_with_k_group_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
config.dynamicSmemBytes = dynamic_smem_in_bytes;
|
||||
config.stream = stream;
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = false;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
|
||||
topk_values, topk_indices, scores_with_bias, num_tokens,
|
||||
n_group, topk_group, topk, num_experts,
|
||||
num_experts / n_group, renormalize, routed_scaling_factor);
|
||||
group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
|
||||
BLOCK_SIZE,
|
||||
dynamic_smem_in_bytes,
|
||||
stream>>>(scores,
|
||||
group_scores,
|
||||
topk_values,
|
||||
topk_indices,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
num_experts,
|
||||
num_experts / n_group,
|
||||
routed_scaling_factor);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||
@@ -704,7 +564,6 @@ void invokeNoAuxTc(T* scores,
|
||||
int64_t const n_group, \
|
||||
int64_t const topk_group, \
|
||||
int64_t const topk, \
|
||||
bool const renormalize, \
|
||||
double const routed_scaling_factor, \
|
||||
cudaStream_t const stream);
|
||||
|
||||
|
@@ -3,158 +3,6 @@
|
||||
|
||||
#include "quantization/common.cuh"
|
||||
|
||||
// adapted from: https://github.com/sgl-project/sglang/blob/v0.5.2rc2/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. Warp‑local, no shared memory
|
||||
// • One warp handles one token.
|
||||
// • Eight tokens per 256‑thread CTA.
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
|
||||
__global__ void per_token_quant_fp8_kernel(
|
||||
const T* __restrict__ input,
|
||||
DST_DTYPE* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
const float scale_ub,
|
||||
const int64_t hidden_size,
|
||||
const int64_t num_tokens) {
|
||||
const int warp_id = threadIdx.x / WARP_SIZE; // 0‑7 (8 warps)
|
||||
const int lane_id = threadIdx.x & (WARP_SIZE - 1); // 0‑31
|
||||
const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
|
||||
if (token_id >= num_tokens) return;
|
||||
|
||||
// Global tensors for this token
|
||||
const T* token_input = input + token_id * hidden_size;
|
||||
DST_DTYPE* token_output = output_q + token_id * hidden_size;
|
||||
float* token_scale = output_s + token_id;
|
||||
|
||||
//
|
||||
// Pass-1: Perform a warp reduce to find the max_value of a token's hidden_size
|
||||
//
|
||||
float max_value = 0.f;
|
||||
using vec_t = AlignedVector<T, kVecSize>;
|
||||
const int32_t num_vec_elems = hidden_size / kVecSize;
|
||||
|
||||
for (int32_t i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
|
||||
}
|
||||
}
|
||||
|
||||
float warp_max = warpReduceMax(max_value);
|
||||
if (scale_ub > 0){
|
||||
warp_max = fminf(warp_max, scale_ub);
|
||||
}
|
||||
float scale;
|
||||
scale = warp_max / FP8_E4M3_MAX;
|
||||
// Broadcast scale
|
||||
if (lane_id == 0) {
|
||||
token_scale[0] = scale;
|
||||
}
|
||||
float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
|
||||
|
||||
//
|
||||
// Pass-2: quantize and write back
|
||||
//
|
||||
for (int i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
DST_DTYPE output_arr[kVecSize];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]) * scale_inv;
|
||||
val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
output_arr[j] = static_cast<DST_DTYPE>(val);
|
||||
}
|
||||
if constexpr (kVecSize == 16) {
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
} else {
|
||||
// Use element-wise copy for vector size 8 to ensure correctness
|
||||
for (int k = 0; k < kVecSize; ++k) {
|
||||
token_output[i * kVecSize + k] = output_arr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename T, typename DST_DTYPE, int kVecSize = 16>
|
||||
__global__ void per_token_quant_fp8_small_batch_kernel(
|
||||
const T* __restrict__ input,
|
||||
DST_DTYPE* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
const float scale_ub,
|
||||
const int64_t hidden_size,
|
||||
const int64_t num_tokens) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= num_tokens) return;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int block_dim = blockDim.x;
|
||||
|
||||
const T* token_input = input + token_idx * hidden_size;
|
||||
DST_DTYPE* token_output = output_q + token_idx * hidden_size;
|
||||
|
||||
float max_value = 0.0f;
|
||||
|
||||
// Use template parameter for vector size
|
||||
using vec_t = AlignedVector<T, kVecSize>;
|
||||
const int32_t num_vec_elems = hidden_size / kVecSize;
|
||||
|
||||
// Find max using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]);
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
}
|
||||
|
||||
max_value = blockReduceMax(max_value);
|
||||
if (scale_ub > 0){
|
||||
max_value = fminf(max_value, scale_ub);
|
||||
}
|
||||
__shared__ float scale;
|
||||
if (tid == 0) {
|
||||
scale = max_value / FP8_E4M3_MAX;
|
||||
output_s[token_idx] = scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float scale_inv = 1.0f / scale;
|
||||
|
||||
// Quantize using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
DST_DTYPE output_arr[kVecSize];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
output_arr[j] = static_cast<DST_DTYPE>(val);
|
||||
}
|
||||
|
||||
if constexpr (kVecSize == 16) {
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
} else {
|
||||
// Use element-wise copy for vector size 8 to ensure correctness
|
||||
for (int k = 0; k < kVecSize; ++k) {
|
||||
token_output[i * kVecSize + k] = output_arr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
@@ -331,78 +179,39 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, // [..., d]
|
||||
auto rank = input.dims().size();
|
||||
int const hidden_size = input.dims()[rank - 1];
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
cudaStream_t stream = input.stream();
|
||||
|
||||
if (hidden_size % 8 == 0){
|
||||
int device = 0;
|
||||
cudaGetDevice(&device);
|
||||
int sm_count = 0;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
|
||||
const int TOKENS_PER_CTA = 8;
|
||||
const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
|
||||
const bool use_vec16 = (hidden_size % 16 == 0);
|
||||
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
|
||||
if (use_warp_kernel) {
|
||||
// -------- warp‑local ---------------------------------------------------
|
||||
constexpr int THREADS = TOKENS_PER_CTA * WARP_SIZE; // 256
|
||||
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
|
||||
dim3 block(THREADS);
|
||||
|
||||
if (use_vec16) {
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
} else {
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
}
|
||||
} else {
|
||||
// -------- baseline -----------------------------------------------------
|
||||
constexpr int THREADS = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(THREADS);
|
||||
|
||||
if (use_vec16) {
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
} else {
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 1024));
|
||||
|
||||
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
|
||||
cudaStream_t stream = input.stream();
|
||||
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::FLOAT32: {
|
||||
using scalar_t = float;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
});
|
||||
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
using scalar_t = phi::dtype::float16;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
using scalar_t = phi::dtype::bfloat16;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(static_scaled_fp8_quant)
|
||||
|
@@ -15,31 +15,72 @@
|
||||
#include "helper.h"
|
||||
|
||||
__global__ void recover_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
}
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void recover_spec_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
int64_t *draft_tokens,
|
||||
const int64_t *step_draft_tokens,
|
||||
const int *step_seq_lens_this_time,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size,
|
||||
const int draft_tokens_len,
|
||||
const int num_extra_tokens) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size;
|
||||
max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq);
|
||||
if (block_table_now[max_possible_block_idx] != -1) {
|
||||
// can be recovered for decoding
|
||||
int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len;
|
||||
const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len;
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx];
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) {
|
||||
draft_tokens_now[i] = step_draft_tokens_now[i];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -47,7 +88,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
@@ -56,17 +101,38 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
#endif
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
if (draft_tokens) {
|
||||
const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1];
|
||||
recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<int64_t *>(draft_tokens.get_ptr()->data<int64_t>()),
|
||||
step_draft_tokens.get_ptr()->data<int64_t>(),
|
||||
step_seq_lens_this_time.get_ptr()->data<int>(),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
draft_tokens_len,
|
||||
max_draft_tokens * 2 + 1);
|
||||
|
||||
} else {
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
@@ -76,8 +142,11 @@ PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"block_tables",
|
||||
"is_block_step"})
|
||||
.Attrs({"block_size: int"})
|
||||
"is_block_step",
|
||||
paddle::Optional("draft_tokens"),
|
||||
paddle::Optional("step_draft_tokens"),
|
||||
paddle::Optional("step_seq_lens_this_time")})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
|
@@ -75,7 +75,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -45,7 +45,7 @@ void save_kernel(const paddle::Tensor& x,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -34,7 +34,7 @@ __global__ void set_value_by_flag_and_id(const bool *stop_flags,
|
||||
const int64_t *input_ids_now = input_ids + tid * length_input_ids;
|
||||
const int seq_len_dec = seq_lens_decoder[tid];
|
||||
const int seq_len_enc = seq_lens_encoder[tid];
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
|
||||
if (step_idx[tid] >= 0) {
|
||||
if (seq_len_enc > 0) { // encoder, get last token accord to seq_lens_encoder
|
||||
pre_ids_all_now[step_idx[tid]] = input_ids_now[seq_len_enc - 1];
|
@@ -15,7 +15,48 @@
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
|
||||
|
||||
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
|
||||
do { \
|
||||
constexpr int BlockSize = BLOCK_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
|
||||
do { \
|
||||
if (truncate_first_token) { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
|
||||
do { \
|
||||
if (kvcache_scheduler_v1) { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
|
||||
do { \
|
||||
if (splitwise_prefill) { \
|
||||
constexpr bool SPLITWISE_PREFILL = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool SPLITWISE_PREFILL = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
|
||||
__global__ void process_splitwise_prefill(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill(
|
||||
stop_flags[tid] = false;
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
int position = seq_len_encoder;
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
@@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill(
|
||||
|
||||
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
|
||||
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
|
||||
__global__ void draft_model_preprocess_kernel(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel(
|
||||
base_model_draft_tokens_now[i] = -1;
|
||||
}
|
||||
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
// 1. process block_step situation
|
||||
// -- In v0 mode, block_step will drop mtp query.
|
||||
// -- In v1 mode, block_step will continue to infer.
|
||||
if constexpr(KVCACHE_SCHEDULER_V1) {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
stop_flags[tid] = true;
|
||||
is_block_step[tid] = true;
|
||||
// Need to continue infer
|
||||
}
|
||||
} else {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. process normal query, not in any special case.
|
||||
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
|
||||
not_stop_flag = 1;
|
||||
// 1. first token
|
||||
// prefill generation
|
||||
if (seq_lens_encoder[tid] > 0) {
|
||||
// Can be extended to first few tokens
|
||||
int seq_len_encoder = seq_lens_encoder[tid];
|
||||
@@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
pre_ids_now[0] = base_model_first_token;
|
||||
int position = seq_len_encoder;
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
input_ids_now[position] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
||||
}
|
||||
} else {
|
||||
} else { // decode generation
|
||||
if constexpr (KVCACHE_SCHEDULER_V1) {
|
||||
// 3. try to recover mtp infer in V1 mode
|
||||
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
|
||||
is_block_step[tid] = false;
|
||||
}
|
||||
}
|
||||
if (stop_flags[tid]) {
|
||||
stop_flags[tid] = false;
|
||||
// TODO: check
|
||||
@@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool TRCUNCATE_FIRST_TOKEN>
|
||||
void DispatchRunner(
|
||||
const cudaStream_t& stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool splitwise_prefill) {
|
||||
constexpr int BlockSize = 512;
|
||||
if (splitwise_prefill) {
|
||||
process_splitwise_prefill<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
}
|
||||
|
||||
void DispatchTokenMode(
|
||||
void DispatchRunner(
|
||||
const cudaStream_t &stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -291,6 +261,7 @@ void DispatchTokenMode(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -310,75 +281,79 @@ void DispatchTokenMode(
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill) {
|
||||
if (truncate_first_token) {
|
||||
DispatchRunner<true>(
|
||||
stream,
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
} else {
|
||||
DispatchRunner<false>(
|
||||
stream,
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
}
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
DISPATCH_BLOCKSIZE(512, {
|
||||
DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
|
||||
DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
|
||||
DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
|
||||
if constexpr (SPLITWISE_PREFILL) {
|
||||
process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
@@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
@@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int num_model_step,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill) {
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
int accept_tokens_len = accept_tokens.shape()[1];
|
||||
int input_ids_len = input_ids.shape()[1];
|
||||
@@ -412,36 +389,38 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
auto not_need_stop_gpu =
|
||||
not_need_stop.copy_to(seq_lens_this_time.place(), false);
|
||||
|
||||
DispatchTokenMode(
|
||||
cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill);
|
||||
DispatchRunner(
|
||||
cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(is_block_step.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
@@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"seq_lens_decoder",
|
||||
"step_idx",
|
||||
"not_need_stop",
|
||||
"is_block_step",
|
||||
"batch_drop",
|
||||
"pre_ids",
|
||||
"accept_tokens",
|
||||
@@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"not_need_stop_out",
|
||||
"batch_drop_out",
|
||||
"pre_ids_out"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
|
@@ -63,7 +63,7 @@ __global__ void ComputeOrderKernel(
|
||||
position_map[in_offset++] = out_offset++;
|
||||
}
|
||||
in_offset += cur_base_model_seq_lens_this_time - accept_num;
|
||||
// (liuzichang): Temperary Reserved for debug
|
||||
// (liuzichang): Temporary Reserved for debug
|
||||
// if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
|
||||
// #ifdef DEBUG_EAGLE_KERNEL
|
||||
// printf("batch %d: accept_num <= actual_draft_token_num \n", i);
|
@@ -139,6 +139,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
|
||||
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
|
||||
.Inputs({"input_ids",
|
||||
"draft_tokens",
|
||||
"cum_offsets"
|
||||
"token_num",
|
||||
"seq_len",
|
||||
"seq_lens_encoder"})
|
||||
|
@@ -0,0 +1,176 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void speculate_schedula_cache(
|
||||
const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq) {
|
||||
const int bid = threadIdx.x;
|
||||
int stop_flag_now_int = 0;
|
||||
if (bid < real_bsz) {
|
||||
if (!stop_flags[bid]) {
|
||||
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
|
||||
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
|
||||
int *block_table_now = block_tables + bid * block_num_per_seq;
|
||||
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
|
||||
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
|
||||
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
|
||||
is_block_step[bid] = true;
|
||||
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
|
||||
seq_lens_this_time[bid] = 0;
|
||||
stop_flags[bid] = true;
|
||||
stop_flag_now_int = 1;
|
||||
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
|
||||
seq_lens_decoder[bid] = 0;
|
||||
accept_num[bid] = 0;
|
||||
for (int i = 0; i < accept_tokens_len; i++) {
|
||||
accept_tokens_now[i] = -1;
|
||||
}
|
||||
for (int i = 0; i < draft_tokens_len; i++) {
|
||||
step_draft_tokens_now[i] = draft_tokens_now[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
} else if (bid >= real_bsz && bid < max_bsz) {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
__syncthreads();
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
}
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int accept_tokens_len = accept_tokens.shape()[1];
|
||||
const int draft_token_len = draft_tokens.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
|
||||
constexpr int BlockSize = 512;
|
||||
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
|
||||
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
draft_tokens.data<int64_t>(),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
|
||||
const_cast<int *>(step_seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_token_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq
|
||||
);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
.Inputs({"draft_tokens",
|
||||
"block_tables",
|
||||
"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"step_draft_tokens",
|
||||
"step_seq_lens_this_time",
|
||||
"accept_num",
|
||||
"accept_tokens",
|
||||
"is_block_step",
|
||||
"not_need_stop",
|
||||
"stop_nums"})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"draft_tokens_out",
|
||||
"block_tables_out",
|
||||
"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"seq_lens_decoder_out",
|
||||
"step_seq_lens_decoder_out",
|
||||
"step_draft_tokens_out",
|
||||
"step_seq_lens_this_time_out",
|
||||
"accept_num_out",
|
||||
"accept_tokens_out",
|
||||
"is_block_step_out",
|
||||
"not_need_stop_out"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"block_tables", "block_tables_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"step_draft_tokens", "step_draft_tokens_out"},
|
||||
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
|
||||
{"accept_num", "accept_num_out"},
|
||||
{"accept_tokens", "accept_tokens_out"},
|
||||
{"is_block_step", "is_block_step_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
|
@@ -35,7 +35,7 @@ __global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
|
||||
accept_tokens + tid * max_draft_tokens;
|
||||
const int seq_len_dec = seq_lens_decoder[tid];
|
||||
const int seq_len_enc = seq_lens_encoder[tid];
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
|
||||
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
|
||||
// printf("step_idx[tid] %d\n", step_idx[tid]);
|
||||
if (step_idx[tid] >= 0) {
|
||||
for (int i = 0; i < accept_num[tid]; i++) {
|
@@ -295,7 +295,7 @@ void SpeculateStepSchedule(const paddle::Tensor &stop_flags,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -283,7 +283,7 @@ void Schedule(const paddle::Tensor &stop_flags,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -58,7 +58,7 @@ class TokenTransfer {
|
||||
}
|
||||
|
||||
// once copy: cpu --> cpu
|
||||
// arrary length should be (1 + MAX_BATCH)
|
||||
// array length should be (1 + MAX_BATCH)
|
||||
bool GetBatchToken(int64_t *array) {
|
||||
if (Empty()) {
|
||||
return false;
|
||||
|
@@ -1,71 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "cuda_multiprocess.h"
|
||||
|
||||
#if !defined(_WIN32)
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#endif
|
||||
|
||||
// 可选:仅删除/解除共享内存命名对象(不依赖之前保存的 addr/fd)
|
||||
static inline int sharedMemoryUnlinkByName(const char* name) {
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
|
||||
// Windows 上没有 shm_unlink 语义。命名对象在最后一个句柄关闭后消失。
|
||||
// 这里做“尽力而为”:尝试打开后立即关闭,减少一次引用。
|
||||
HANDLE hMap = OpenFileMappingA(FILE_MAP_ALL_ACCESS, FALSE, name);
|
||||
if (hMap) {
|
||||
CloseHandle(hMap);
|
||||
return 0;
|
||||
}
|
||||
// 已经不存在也算成功
|
||||
return 0;
|
||||
#else
|
||||
// POSIX: 移除名字,未来不可再 open;已映射区仍存活直至 munmap
|
||||
if (shm_unlink(name) != 0) {
|
||||
if (errno == ENOENT) return 0; // 不存在视作成功
|
||||
return errno;
|
||||
}
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void UnsetDataIpc(const paddle::Tensor& tmp_input,
|
||||
const std::string& shm_name,
|
||||
bool close_ipc,
|
||||
bool unlink_shm) {
|
||||
// 1) 关闭消费者导入的 IPC 映射(仅当 close_ipc=true 且该指针确为 OpenMemHandle 得来)
|
||||
if (close_ipc) {
|
||||
void* ptr = const_cast<void*>(tmp_input.data());
|
||||
checkCudaErrors(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
|
||||
// 2) 解除共享内存命名对象(仅处理“名字”,不保证解除旧映射)
|
||||
if (unlink_shm) {
|
||||
int rc = sharedMemoryUnlinkByName(shm_name.c_str());
|
||||
if (rc != 0) {
|
||||
PD_THROW("Unlink shared memory failed: name=%s, err=%d",
|
||||
shm_name.c_str(), rc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(unset_data_ipc)
|
||||
.Inputs({"tmp_input"})
|
||||
.Attrs({"shm_name: std::string", "close_ipc: bool", "unlink_shm: bool"})
|
||||
.SetKernelFn(PD_KERNEL(UnsetDataIpc));
|
@@ -75,10 +75,10 @@ void UpdateSplitFuseInputes(const paddle::Tensor& split_fuse_seq_lens,
|
||||
const int max_seq_len,
|
||||
const int max_batch_size,
|
||||
const int split_fuse_size) {
|
||||
dim3 girds;
|
||||
girds.x = max_batch_size;
|
||||
dim3 grids;
|
||||
grids.x = max_batch_size;
|
||||
const int block_size = 128;
|
||||
update_split_fuse_inputs_kernel<<<girds,
|
||||
update_split_fuse_inputs_kernel<<<grids,
|
||||
block_size,
|
||||
0,
|
||||
input_ids.stream()>>>(
|
||||
|
@@ -110,7 +110,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill sucess ?)
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
@@ -37,52 +37,6 @@ def load_module_from_path(module_name, path):
|
||||
return module
|
||||
|
||||
|
||||
def update_git_repo():
|
||||
try:
|
||||
print("update third party repo...", flush=True)
|
||||
original_dir = os.getcwd()
|
||||
submodule_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
third_party_path = os.path.join(submodule_dir, "third_party")
|
||||
root_path = Path(third_party_path)
|
||||
|
||||
# check if third_party is empty
|
||||
update_third_party = False
|
||||
for dirpath in root_path.iterdir():
|
||||
if dirpath.is_dir():
|
||||
has_content = any(dirpath.iterdir())
|
||||
if not has_content:
|
||||
update_third_party = True
|
||||
|
||||
if update_third_party:
|
||||
os.chdir(submodule_dir)
|
||||
subprocess.run(
|
||||
"git submodule sync --recursive && git submodule update --init --recursive",
|
||||
shell=True,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\033[33m[===WARNING===]third_party directory already exists, skip clone and update.\033[0m",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# apply deep gemm patch
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
dst_path = os.path.join(submodule_dir, deep_gemm_dir)
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
patch_source = os.path.join(submodule_dir, patch)
|
||||
patch_destination = os.path.join(dst_path, patch)
|
||||
if not os.path.exists(patch_destination):
|
||||
shutil.copy(patch_source, patch_destination)
|
||||
apply_cmd = ["git", "apply", patch]
|
||||
os.chdir(dst_path)
|
||||
subprocess.run(apply_cmd, check=True)
|
||||
os.chdir(original_dir)
|
||||
except subprocess.CalledProcessError:
|
||||
raise Exception("Git submodule update and apply patch failed. Maybe network connection is poor.")
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
|
||||
# cannot import envs directly because it depends on fastdeploy,
|
||||
@@ -92,8 +46,6 @@ envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.
|
||||
archs = json.loads(envs.FD_BUILDING_ARCS)
|
||||
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
|
||||
|
||||
update_git_repo()
|
||||
|
||||
|
||||
def download_and_extract(url, destination_directory):
|
||||
"""
|
||||
@@ -126,6 +78,52 @@ def download_and_extract(url, destination_directory):
|
||||
print(f"Error extracting file: {e}")
|
||||
|
||||
|
||||
def clone_git_repo(version, repo_url, destination_path):
|
||||
"""
|
||||
Clone git repo to destination path.
|
||||
"""
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"-b",
|
||||
version,
|
||||
"--single-branch",
|
||||
repo_url,
|
||||
destination_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def process_git_repo(cur_path, dst_path, commit_id=None, patch=None):
|
||||
"""
|
||||
reset git repo to destination commit and apply patch.
|
||||
"""
|
||||
if commit_id is not None:
|
||||
reset_cmd = ["git", "reset", "--hard", commit_id]
|
||||
if patch is not None:
|
||||
patch_source = os.path.join(cur_path, patch)
|
||||
patch_destination = os.path.join(dst_path, patch)
|
||||
shutil.copy(patch_source, patch_destination)
|
||||
apply_cmd = ["git", "apply", patch]
|
||||
|
||||
try:
|
||||
os.chdir(dst_path)
|
||||
if commit_id is not None:
|
||||
subprocess.run(reset_cmd, check=True)
|
||||
if patch is not None:
|
||||
subprocess.run(apply_cmd, check=True)
|
||||
os.chdir(cur_path)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def get_sm_version(archs):
|
||||
"""
|
||||
Get sm version of paddle.
|
||||
@@ -193,13 +191,20 @@ def find_end_files(directory, end_str):
|
||||
if paddle.is_compiled_with_rocm():
|
||||
# NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
|
||||
# so we need to check if paddle compiled with rocm at first.
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
"gpu_ops/get_output.cc",
|
||||
"gpu_ops/get_output_msg_with_topk.cc",
|
||||
"gpu_ops/save_output_msg_with_topk.cc",
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/token_penalty_multi_scores.cu",
|
||||
"gpu_ops/stop_generation.cu",
|
||||
"gpu_ops/stop_generation_multi_ends.cu",
|
||||
@@ -208,7 +213,6 @@ if paddle.is_compiled_with_rocm():
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/step.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/unset_data_ipc.cu",
|
||||
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
||||
"gpu_ops/step_system_cache.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
@@ -219,7 +223,7 @@ if paddle.is_compiled_with_rocm():
|
||||
"gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_get_seq_lens_output.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_save_output.cc",
|
||||
"gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_step.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
|
||||
"gpu_ops/speculate_decoding/speculate_update_v3.cu",
|
||||
@@ -257,7 +261,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/save_output_msg_with_topk.cc",
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"gpu_ops/set_mask_value.cu",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/ngram_mask.cu",
|
||||
"gpu_ops/gather_idx.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
@@ -272,14 +276,13 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/recover_decode_task.cu",
|
||||
"gpu_ops/step.cu",
|
||||
"gpu_ops/step_reschedule.cu",
|
||||
"gpu_ops/fused_get_rope.cu",
|
||||
"gpu_ops/fused_get_rotary_embedding.cu",
|
||||
"gpu_ops/get_padding_offset.cu",
|
||||
"gpu_ops/update_inputs.cu",
|
||||
"gpu_ops/update_inputs_beam.cu",
|
||||
"gpu_ops/beam_search_softmax.cu",
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/unset_data_ipc.cu",
|
||||
"gpu_ops/read_data_ipc.cu",
|
||||
"gpu_ops/enforce_generation.cu",
|
||||
"gpu_ops/dequant_int8.cu",
|
||||
@@ -313,6 +316,28 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
|
||||
]
|
||||
|
||||
cutlass_dir = "third_party/cutlass"
|
||||
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
|
||||
if not os.path.exists(cutlass_dir):
|
||||
os.makedirs(cutlass_dir)
|
||||
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
|
||||
if not os.listdir(cutlass_dir):
|
||||
raise ValueError("Git clone cutlass failed!")
|
||||
|
||||
# deep gemm
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
|
||||
if not os.path.exists(deep_gemm_dir):
|
||||
os.makedirs(deep_gemm_dir)
|
||||
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
|
||||
if not os.listdir(deep_gemm_dir):
|
||||
raise ValueError("Git clone DeepGEMM failed!")
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
dst_path = os.path.join(cur_path, deep_gemm_dir)
|
||||
commit_id = "95e81b3dd6704e279e5f4757c5b94776ac988a8d"
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
process_git_repo(cur_path, dst_path, commit_id, patch)
|
||||
|
||||
dg_third_party_include_dirs = (
|
||||
"third_party/cutlass/include/cute",
|
||||
"third_party/cutlass/include/cutlass",
|
||||
@@ -340,6 +365,14 @@ elif paddle.is_compiled_with_cuda():
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
|
||||
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
|
||||
cc_compile_args = []
|
||||
nvcc_compile_args = get_gencode_flags(archs)
|
||||
nvcc_compile_args += ["-DPADDLE_DEV"]
|
||||
@@ -474,7 +507,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
sources += find_end_files(fp8_auto_gen_directory, ".cu")
|
||||
|
||||
if cc >= 90 and nvcc_version >= 12.0:
|
||||
# Hopper optmized mla
|
||||
# Hopper optimized mla
|
||||
sources += find_end_files("gpu_ops/mla_attn", ".cu")
|
||||
sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"]
|
||||
sources += find_end_files("gpu_ops/moba_attn/moba_decoder_attn/", ".cu")
|
||||
@@ -527,7 +560,7 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
"gpu_ops/save_output_msg_with_topk.cc",
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"gpu_ops/get_padding_offset.cu",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/update_inputs.cu",
|
||||
"gpu_ops/stop_generation_multi_ends.cu",
|
||||
@@ -560,6 +593,13 @@ elif paddle.is_compiled_with_custom_device("gcu"):
|
||||
)
|
||||
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
maca_path = os.getenv("MACA_PATH", "/opt/maca")
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://gitee.com/learnlov/mirrors_nlohmann_json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/update_inputs_v1.cu",
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
@@ -569,7 +609,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
"gpu_ops/transfer_output.cc",
|
||||
"gpu_ops/save_with_output.cc",
|
||||
"gpu_ops/set_mask_value.cu",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/set_value_by_flags_and_idx.cu",
|
||||
"gpu_ops/ngram_mask.cu",
|
||||
"gpu_ops/gather_idx.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
@@ -578,7 +618,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
"gpu_ops/stop_generation.cu",
|
||||
"gpu_ops/stop_generation_multi_ends.cu",
|
||||
"gpu_ops/set_flags.cu",
|
||||
"gpu_ops/fused_get_rope.cu",
|
||||
"gpu_ops/fused_get_rotary_embedding.cu",
|
||||
"gpu_ops/get_padding_offset.cu",
|
||||
"gpu_ops/update_inputs.cu",
|
||||
"gpu_ops/update_inputs_beam.cu",
|
||||
|
1
custom_ops/third_party/DeepGEMM
vendored
1
custom_ops/third_party/DeepGEMM
vendored
Submodule custom_ops/third_party/DeepGEMM deleted from 95e81b3dd6
1
custom_ops/third_party/cutlass
vendored
1
custom_ops/third_party/cutlass
vendored
Submodule custom_ops/third_party/cutlass deleted from afa1772203
1
custom_ops/third_party/nlohmann_json
vendored
1
custom_ops/third_party/nlohmann_json
vendored
Submodule custom_ops/third_party/nlohmann_json deleted from 9cca280a4d
@@ -67,7 +67,7 @@ std::vector<paddle::Tensor> MoeLayerKernel(
|
||||
const auto xtype = x.dtype();
|
||||
auto x_dims = x.shape();
|
||||
auto up_gate_proj_dims = up_gate_proj_weight.shape();
|
||||
PD_CHECK(x_dims.size() == 2, "x_dims.size() shoud be 2.");
|
||||
PD_CHECK(x_dims.size() == 2, "x_dims.size() should be 2.");
|
||||
PD_CHECK(up_gate_proj_dims.size() == 3, "up_gate_proj_dims.size() should be 3.");
|
||||
PD_CHECK(down_proj_in_scale.get_ptr() == nullptr, "down_proj_in_scale not support.");
|
||||
if (quant_method == "weight_only_int4") {
|
||||
|
@@ -122,7 +122,7 @@ void SpeculateStepSchedule(
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -59,7 +59,7 @@ void SaveOutMmsg(const paddle::Tensor &x, const paddle::Tensor ¬_need_stop,
|
||||
std::string inference_msg_id_env_str(inference_msg_id_env_p);
|
||||
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
|
||||
if (inference_msg_id_from_env == 2) {
|
||||
// 2 and -2 is perserve for no-output indication.
|
||||
// 2 and -2 is preserve for no-output indication.
|
||||
throw std::runtime_error(
|
||||
" INFERENCE_MSG_ID cannot be 2, please use other number.");
|
||||
}
|
||||
|
@@ -4,7 +4,7 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
#define MAX_LM_SIZE 28672
|
||||
// One core has 32KB LM(gropu LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// the stack space
|
||||
#define MAX_BATCH 512
|
||||
#define ALIGNMENT 64
|
||||
|
@@ -4,7 +4,7 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
#define MAX_LM_SIZE 28672
|
||||
// One core has 32KB LM(gropu LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// the stack space
|
||||
#define MAX_BATCH 512
|
||||
#define ALIGNMENT 64
|
||||
|
@@ -8,7 +8,7 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
#define MAX_SM_SIZE 32768
|
||||
// One core has 32KB LM(gropu LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is
|
||||
// the stack space
|
||||
#define MAX_BATCH 512
|
||||
#define BANK_CONFLICT_M 128
|
||||
|
@@ -79,7 +79,7 @@ qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
|
||||
# print("wscale_pd:\n{}".format(wscale_pd))
|
||||
# print("wscale_np:\n{}".format(wscale_np))
|
||||
|
||||
# comparation
|
||||
# comparison
|
||||
print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}")
|
||||
print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}")
|
||||
print(f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}")
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user