mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
Compare commits
95 Commits
develop
...
Jason/expe
Author | SHA1 | Date | |
---|---|---|---|
![]() |
8614ca56ad | ||
![]() |
c35a21a99a | ||
![]() |
c8985727a6 | ||
![]() |
076c30cb0f | ||
![]() |
f8c6a354a1 | ||
![]() |
b176cba474 | ||
![]() |
dcf633c4d9 | ||
![]() |
213f15ef55 | ||
![]() |
bab779011c | ||
![]() |
e2b68b33c9 | ||
![]() |
8a506500f3 | ||
![]() |
1aab1c8d06 | ||
![]() |
94b6e7a341 | ||
![]() |
389c5dd3a2 | ||
![]() |
361104508e | ||
![]() |
0bfffdbc14 | ||
![]() |
f489c9f8ef | ||
![]() |
be98f6e950 | ||
![]() |
f75697c2d1 | ||
![]() |
1e86418c4a | ||
![]() |
5027ed7239 | ||
![]() |
25aa2d94aa | ||
![]() |
b6caf6e622 | ||
![]() |
d381fa8194 | ||
![]() |
d2ab369427 | ||
![]() |
2883746132 | ||
![]() |
2485333f71 | ||
![]() |
10768a4d79 | ||
![]() |
c64ceac34d | ||
![]() |
447297a7b5 | ||
![]() |
63d24b2210 | ||
![]() |
48f2ab3fb3 | ||
![]() |
749f074e44 | ||
![]() |
f06e3ee1fc | ||
![]() |
2f473ba966 | ||
![]() |
cce2410fad | ||
![]() |
d8985a7a21 | ||
![]() |
7d1b2bd732 | ||
![]() |
71a9127e13 | ||
![]() |
8f5397616f | ||
![]() |
ece070cf6b | ||
![]() |
d40a1046de | ||
![]() |
fa2369271d | ||
![]() |
8903f937f9 | ||
![]() |
1023a67765 | ||
![]() |
d43549953c | ||
![]() |
c7c1627456 | ||
![]() |
d6bf6de5e6 | ||
![]() |
38e734e183 | ||
![]() |
051e4a881c | ||
![]() |
b2bb37d7c0 | ||
![]() |
c6e2a37a95 | ||
![]() |
8d77c1cb51 | ||
![]() |
41cd3e24c9 | ||
![]() |
11b18e5ef0 | ||
![]() |
e2c764fd5a | ||
![]() |
2d975e16b0 | ||
![]() |
8915c8411d | ||
![]() |
77c1bd0813 | ||
![]() |
473cde779f | ||
![]() |
335d1c8e8f | ||
![]() |
173e4df982 | ||
![]() |
199f88ce1e | ||
![]() |
55ebe855c0 | ||
![]() |
deb7ad205f | ||
![]() |
e9f72df918 | ||
![]() |
8567ada09e | ||
![]() |
afcde19277 | ||
![]() |
d40d3a5a4f | ||
![]() |
b8d0f1c081 | ||
![]() |
8550e19008 | ||
![]() |
a0c03510c0 | ||
![]() |
fb1e0d6a87 | ||
![]() |
fbf0e9d2aa | ||
![]() |
8c0e7d6fe9 | ||
![]() |
b56b015d85 | ||
![]() |
1432e336d7 | ||
![]() |
9213a58a06 | ||
![]() |
87ef0f5d30 | ||
![]() |
abcd2148c0 | ||
![]() |
05b6591c23 | ||
![]() |
42402c80e9 | ||
![]() |
1968c65849 | ||
![]() |
37cb37b7f2 | ||
![]() |
f975f7de2f | ||
![]() |
174510180a | ||
![]() |
5cda326ba2 | ||
![]() |
a6c8f17431 | ||
![]() |
cd09384a14 | ||
![]() |
0f42771a84 | ||
![]() |
d1d063e4af | ||
![]() |
a86b35ab49 | ||
![]() |
0cdbc950b5 | ||
![]() |
2b0a745d57 | ||
![]() |
1953c7c759 |
1
.github/workflows/Codestyle-Check.yml
vendored
1
.github/workflows/Codestyle-Check.yml
vendored
@@ -5,6 +5,7 @@ on:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
- 'feature/*'
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
|
3
.github/workflows/_accuracy_test.yml
vendored
3
.github/workflows/_accuracy_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -160,7 +160,6 @@ jobs:
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
pushd tests/ce/deploy
|
||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
|
6
.github/workflows/_base_test.yml
vendored
6
.github/workflows/_base_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -143,8 +143,7 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
# python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
python -m pip install paddlepaddle-gpu==3.3.0.dev20250917 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
@@ -161,7 +160,6 @@ jobs:
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
pushd tests/ce/deploy
|
||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
|
11
.github/workflows/_build_linux.yml
vendored
11
.github/workflows/_build_linux.yml
vendored
@@ -55,7 +55,7 @@ on:
|
||||
jobs:
|
||||
fd-build:
|
||||
runs-on: [self-hosted, GPU-Build]
|
||||
timeout-minutes: 360
|
||||
timeout-minutes: 240
|
||||
outputs:
|
||||
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
|
||||
steps:
|
||||
@@ -106,12 +106,7 @@ jobs:
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
|
||||
IFS='/' read -ra parts <<< "${GITHUB_WORKSPACE}"
|
||||
len=${#parts[@]}
|
||||
CCACHE_DEFAULT_DIR="/$(IFS=/; echo "${parts[*]:1:$((len-5))}")"
|
||||
echo "$CCACHE_DEFAULT_DIR"
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$CCACHE_DEFAULT_DIR}"
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
@@ -132,7 +127,6 @@ jobs:
|
||||
-e "PADDLEVERSION=${PADDLEVERSION}" \
|
||||
-e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
|
||||
-e "BRANCH_REF=${BRANCH_REF}" \
|
||||
-e "CCACHE_MAXSIZE=50G" \
|
||||
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
|
||||
if [[ -n "${FD_VERSION}" ]]; then
|
||||
export FASTDEPLOY_VERSION=${FD_VERSION}
|
||||
@@ -140,7 +134,6 @@ jobs:
|
||||
fi
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
chown -R $(whoami) /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
|
||||
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
|
||||
|
73
.github/workflows/_ci_image_build.yml
vendored
73
.github/workflows/_ci_image_build.yml
vendored
@@ -1,73 +0,0 @@
|
||||
name: Docker Build
|
||||
description: "FastDeploy CI Image Build"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
CI_DOCKER_IMAGE_NAME:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
DOCKER_IMAGE_NAME:
|
||||
description: "Build Images"
|
||||
required: false
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
||||
outputs:
|
||||
docker_name_precheck:
|
||||
description: "Output path of the generated wheel"
|
||||
value: ${{ jobs.docker_build.outputs.docker_name_precheck }}
|
||||
|
||||
jobs:
|
||||
docker_build:
|
||||
runs-on: [self-hosted, Docker-Build]
|
||||
outputs:
|
||||
docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
|
||||
steps:
|
||||
- name: Docker Build
|
||||
id: docker_build
|
||||
shell: bash
|
||||
env:
|
||||
docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }}
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE_NAME }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
# Docker Build
|
||||
cd tools/dockerfile/
|
||||
set -e
|
||||
cp ../../requirements.txt ./
|
||||
cp ../../scripts/unittest_requirement.txt ./
|
||||
docker build -t ${docker_image_name} -f Dockerfile.ci . \
|
||||
--network host \
|
||||
--no-cache
|
||||
docker push ${docker_image_name}
|
||||
echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT
|
3
.github/workflows/_logprob_test_linux.yml
vendored
3
.github/workflows/_logprob_test_linux.yml
vendored
@@ -39,7 +39,6 @@ jobs:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
|
||||
run: |
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -117,6 +116,7 @@ jobs:
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
@@ -147,7 +147,6 @@ jobs:
|
||||
--skip install
|
||||
|
||||
cd PaddleTest/framework/ServeTest
|
||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
|
5
.github/workflows/_pre_ce_test.yml
vendored
5
.github/workflows/_pre_ce_test.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -82,9 +82,6 @@ jobs:
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
FD_ZMQ_RECV_REQUEST_SERVER_PORT=$((42048 + DEVICE_PORT * 100))
|
||||
FD_ZMQ_SEND_RESPONSE_SERVER_PORT=$((42038 + DEVICE_PORT * 100))
|
||||
FD_ZMQ_CONTROL_CMD_SERVER_PORTS=$((42028 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
|
2
.github/workflows/_stable_test.yml
vendored
2
.github/workflows/_stable_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
|
9
.github/workflows/_unit_test_coverage.yml
vendored
9
.github/workflows/_unit_test_coverage.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
|
||||
run_tests_with_coverage:
|
||||
runs-on: [self-hosted, GPU-h1z1-2Cards]
|
||||
timeout-minutes: 90
|
||||
timeout-minutes: 60
|
||||
needs: check_cov_skip
|
||||
if: needs.check_cov_skip.outputs.can-skip != 'true'
|
||||
outputs:
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
@@ -171,7 +171,10 @@ jobs:
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install -r scripts/unittest_requirement.txt
|
||||
python -m pip install coverage
|
||||
python -m pip install diff-cover
|
||||
python -m pip install pytest-cov
|
||||
python -m pip install jsonschema aistudio_sdk==0.3.5
|
||||
python -m pip install ${fd_wheel_url}
|
||||
rm -rf fastdeploy
|
||||
# coverage subprocess use
|
||||
|
1
.github/workflows/approve.yml
vendored
1
.github/workflows/approve.yml
vendored
@@ -5,6 +5,7 @@ on:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
- 'feature/*'
|
||||
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
21
.github/workflows/ce_job.yml
vendored
21
.github/workflows/ce_job.yml
vendored
@@ -6,10 +6,11 @@ on:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
- 'feature/experimental_feature*'
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: CE-Job-${{ github.ref }}-${{ github.sha }}
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
@@ -199,13 +200,13 @@ jobs:
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
|
||||
ce_upload_sm8689:
|
||||
@@ -224,9 +225,9 @@ jobs:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
run: |
|
||||
echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
|
||||
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
|
||||
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
@@ -238,11 +239,11 @@ jobs:
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
|
6
.github/workflows/ci_iluvatar.yml
vendored
6
.github/workflows/ci_iluvatar.yml
vendored
@@ -28,22 +28,18 @@ jobs:
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global http.proxy "http://61.151.249.150:33128"
|
||||
git config --global https.proxy "http://61.151.249.150:33128"
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone --recursive ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
git clone ${REPO} ${REPO_NAME}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
|
174
.github/workflows/ci_image_update.yml
vendored
174
.github/workflows/ci_image_update.yml
vendored
@@ -1,174 +0,0 @@
|
||||
name: CI Images Build
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
|
||||
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: CI-Images-Build-${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
jobs:
|
||||
clone:
|
||||
environment: CodeSync
|
||||
name: FD-Clone-Linux
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Python Setup
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
commit_id=${{ github.sha }}
|
||||
tag_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
|
||||
else
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
||||
|
||||
resultshow:
|
||||
name: Show Code Archive Output
|
||||
needs: clone
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Print wheel path
|
||||
run: |
|
||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
||||
|
||||
ci_image_build:
|
||||
name: CI Images Build
|
||||
needs: clone
|
||||
uses: ./.github/workflows/_ci_image_build.yml
|
||||
with:
|
||||
CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
|
||||
|
||||
build_sm8090:
|
||||
name: BUILD_SM8090
|
||||
needs: [clone, ci_image_build]
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "90"
|
||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
|
||||
unittest_coverage:
|
||||
name: Run FastDeploy Unit Tests and Coverage
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
logprob_test:
|
||||
name: Run FastDeploy LogProb Tests
|
||||
needs: [build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
pre_ce_test:
|
||||
name: Extracted partial CE model tasks to run in CI.
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_pre_ce_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
base_test:
|
||||
name: Run Base Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_base_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
accuracy_test:
|
||||
name: Run Accuracy Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_accuracy_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
|
||||
publish_pre_check:
|
||||
name: Publish Docker Images Pre Check
|
||||
needs: [ci_image_build, unittest_coverage,logprob_test,pre_ce_test,base_test,accuracy_test,stable_test]
|
||||
runs-on: [self-hosted, Docker-Build]
|
||||
steps:
|
||||
- name: Images Uploading
|
||||
env:
|
||||
images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
||||
run: |
|
||||
echo "images_name=${images_name}"
|
||||
docker images ${ci_image_name}
|
||||
docker tag ${images_name} ${ci_image_name}
|
||||
docker push ${ci_image_name}
|
1
.github/workflows/ci_xpu.yml
vendored
1
.github/workflows/ci_xpu.yml
vendored
@@ -5,6 +5,7 @@ on:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
- 'feature/*'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
|
4
.github/workflows/pr_build_and_test.yml
vendored
4
.github/workflows/pr_build_and_test.yml
vendored
@@ -2,7 +2,7 @@ name: PR Build and Test
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
branches: [develop, release/**]
|
||||
branches: [develop, release/**, feature/**]
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "90"
|
||||
COMPILE_ARCH: "89,90"
|
||||
WITH_NIGHTLY_BUILD: "OFF"
|
||||
FD_VERSION: "0.0.0"
|
||||
|
||||
|
12
.github/workflows/publish_job.yml
vendored
12
.github/workflows/publish_job.yml
vendored
@@ -13,7 +13,7 @@ on:
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: Publish-Job-${{ github.ref }}-${{ github.sha }}
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
@@ -319,13 +319,3 @@ jobs:
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
10
.gitmodules
vendored
10
.gitmodules
vendored
@@ -1,10 +0,0 @@
|
||||
[submodule "custom_ops/third_party/DeepGEMM"]
|
||||
path = custom_ops/third_party/DeepGEMM
|
||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||
ignore = all
|
||||
[submodule "custom_ops/third_party/cutlass"]
|
||||
path = custom_ops/third_party/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
[submodule "custom_ops/third_party/nlohmann_json"]
|
||||
path = custom_ops/third_party/nlohmann_json
|
||||
url = https://github.com/nlohmann/json.git
|
20
README.md
20
README.md
@@ -26,8 +26,6 @@ English | [简体中文](README_CN.md)
|
||||
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
|
||||
|
||||
## News
|
||||
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It now offers compatibility with models in the HuggingFace ecosystem, has further optimized performance, and newly adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!
|
||||
|
||||
**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.
|
||||
|
||||
**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
|
||||
@@ -43,7 +41,7 @@ English | [简体中文](README_CN.md)
|
||||
- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility.
|
||||
- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more.
|
||||
- ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill.
|
||||
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU, Intel Gaudi etc.
|
||||
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc.
|
||||
|
||||
## Requirements
|
||||
|
||||
@@ -59,10 +57,8 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
|
||||
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
|
||||
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
|
||||
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
|
||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
|
||||
- [Intel Gaudi](./docs/get_started/installation/intel_gaudi.md)
|
||||
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU and MetaX GPU are currently under development and testing. Stay tuned for updates!
|
||||
|
||||
## Get Started
|
||||
|
||||
@@ -72,12 +68,20 @@ Learn how to use FastDeploy through our documentation:
|
||||
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
|
||||
- [Offline Inference Development](./docs/offline_inference.md)
|
||||
- [Online Service Deployment](./docs/online_serving/README.md)
|
||||
- [Full Supported Models List](./docs/supported_models.md)
|
||||
- [Best Practices](./docs/best_practices/README.md)
|
||||
|
||||
## Supported Models
|
||||
|
||||
Learn how to download models, enable using the torch format, and more:
|
||||
- [Full Supported Models List](./docs/supported_models.md)
|
||||
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|
||||
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|
||||
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
|
||||
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|
||||
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|
||||
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
|
||||
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
|
22
README_CN.md
22
README_CN.md
@@ -26,9 +26,7 @@
|
||||
# FastDeploy :基于飞桨的大语言模型与视觉语言模型推理部署工具包
|
||||
|
||||
## 最新活动
|
||||
**[2025-09] 🔥 FastDeploy v2.2 全新发布**: HuggingFace生态模型兼容,性能进一步优化,更新增对[baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)支持!
|
||||
|
||||
**[2025-08] FastDeploy v2.1 发布**:全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。
|
||||
**[2025-08] 🔥 FastDeploy v2.1 全新发布:** 全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。
|
||||
|
||||
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
|
||||
|
||||
@@ -41,7 +39,7 @@
|
||||
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
|
||||
- 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
|
||||
- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
|
||||
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU、英特尔Gaudi等
|
||||
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
|
||||
|
||||
## 要求
|
||||
|
||||
@@ -57,10 +55,8 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
|
||||
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
|
||||
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
|
||||
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
|
||||
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md)
|
||||
- [英特尔 Gaudi](./docs/zh/get_started/installation/intel_gaudi.md)
|
||||
|
||||
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!
|
||||
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 和 沐曦(MetaX)GPU 在内的其他硬件平台正在开发测试中。敬请关注更新!
|
||||
|
||||
## 入门指南
|
||||
|
||||
@@ -70,12 +66,20 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
|
||||
- [ERNIE-4.5-VL 部署](./docs/zh/get_started/ernie-4.5-vl.md)
|
||||
- [离线推理](./docs/zh/offline_inference.md)
|
||||
- [在线服务](./docs/zh/online_serving/README.md)
|
||||
- [模型支持列表](./docs/zh/supported_models.md)
|
||||
- [最佳实践](./docs/zh/best_practices/README.md)
|
||||
|
||||
## 支持模型列表
|
||||
|
||||
通过我们的文档了解如何下载模型,如何支持torch格式等:
|
||||
- [模型支持列表](./docs/zh/supported_models.md)
|
||||
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|
||||
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|
||||
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
|
||||
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|
||||
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|
||||
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
|
||||
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |
|
||||
|
||||
## 进阶用法
|
||||
|
||||
|
@@ -98,7 +98,7 @@ def main(args):
|
||||
raise ValueError("--max_concurrency should be same length as --s_itl_base_model")
|
||||
|
||||
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
|
||||
# Warmup
|
||||
# Wramup
|
||||
print("Starting warmup...")
|
||||
with open(os.devnull, "w") as f:
|
||||
with contextlib.redirect_stdout(f):
|
||||
|
@@ -965,7 +965,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="openai-chat",
|
||||
default="vllm",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@@ -1,5 +0,0 @@
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
tensor_parallel_size: 4
|
||||
use_cudagraph: True
|
||||
load_choices: "default_v1"
|
@@ -1,6 +0,0 @@
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
tensor_parallel_size: 4
|
||||
use_cudagraph: True
|
||||
load_choices: "default_v1"
|
||||
quantization: wfp8afp8
|
@@ -6,4 +6,3 @@ tensor_parallel_size: 8
|
||||
max_num_batched_tokens: 4096
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
quantization: wint4
|
||||
|
@@ -1,6 +0,0 @@
|
||||
tensor_parallel_size: 1
|
||||
max_model_len: 131072
|
||||
max_num_seqs: 32
|
||||
quantization: wint4
|
||||
max_num_batched_tokens: 8192
|
||||
plas_attention_config: '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
|
@@ -6,4 +6,3 @@ tensor_parallel_size: 8
|
||||
max_num_batched_tokens: 4096
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
quantization: wint8
|
||||
|
@@ -1,5 +0,0 @@
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 256
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 4
|
||||
gpu_memory_utilization: 0.9
|
@@ -13,4 +13,3 @@ pd_comm_port: "2334"
|
||||
max_num_batched_tokens: 384
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
quantization: wint4
|
||||
|
@@ -10,4 +10,3 @@ engine_worker_queue_port: 6677
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7675,7676,7677,7678"
|
||||
pd_comm_port: "2333"
|
||||
quantization: wint4
|
||||
|
@@ -1,11 +0,0 @@
|
||||
enable_mm: True
|
||||
max_model_len: 131072
|
||||
max_num_seqs: 56
|
||||
gpu_memory_utilization: 0.8
|
||||
kv_cache_ratio: 0.8
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint4
|
||||
limit_mm_per_prompt: '{"image": 100, "video": 100}'
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
reasoning_parser: ernie-45-vl
|
@@ -1,7 +1,7 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 36
|
||||
gpu_memory_utilization: 0.9
|
||||
gpu_memory_utilization: 0.95
|
||||
kv_cache_ratio: 0.8
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint8
|
||||
|
@@ -1,7 +1,7 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 36
|
||||
gpu_memory_utilization: 0.85
|
||||
gpu_memory_utilization: 0.8
|
||||
kv_cache_ratio: 0.8
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint8
|
||||
|
@@ -1,9 +0,0 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 1
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
reasoning_parser: ernie-45-vl
|
@@ -1,10 +0,0 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 1
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
quantization: wint4
|
||||
reasoning_parser: ernie-45-vl
|
@@ -1,10 +0,0 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 1
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
quantization: wint8
|
||||
reasoning_parser: ernie-45-vl
|
@@ -1 +0,0 @@
|
||||
max_tokens: 131071
|
@@ -1 +0,0 @@
|
||||
max_tokens: 12288
|
@@ -1,8 +0,0 @@
|
||||
top_p: 0.95
|
||||
temperature: 0.6
|
||||
metadata:
|
||||
min_tokens: 1
|
||||
max_tokens: 131071
|
||||
repetition_penalty: 1.0
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
@@ -2,7 +2,7 @@ top_p: 0.95
|
||||
temperature: 0.6
|
||||
metadata:
|
||||
min_tokens: 1
|
||||
max_tokens: 12288
|
||||
max_tokens: 65535
|
||||
repetition_penalty: 1.0
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
@@ -1,6 +0,0 @@
|
||||
tensor_parallel_size: 1
|
||||
max_model_len: 131072
|
||||
max_num_seqs: 32
|
||||
reasoning_parser: ernie_x1
|
||||
tool_call_parser: ernie_x1
|
||||
load_choices: "default_v1"
|
14
build.sh
14
build.sh
@@ -128,12 +128,6 @@ function copy_ops(){
|
||||
echo -e "MACA ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
|
||||
if [ "$is_intel_hpu" = "True" ]; then
|
||||
DEVICE_TYPE="intel-hpu"
|
||||
echo -e "intel_hpu ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
DEVICE_TYPE="cpu"
|
||||
cd ../../../../
|
||||
@@ -149,9 +143,9 @@ function build_and_install_ops() {
|
||||
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
|
||||
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
|
||||
if [ "$is_xpu" = "True" ]; then
|
||||
cd xpu_ops
|
||||
cd xpu_ops/src
|
||||
bash build.sh ${TMP_DIR_REAL_PATH}
|
||||
cd ..
|
||||
cd ../..
|
||||
elif [ "$FD_CPU_USE_BF16" == "true" ]; then
|
||||
if [ "$FD_BUILDING_ARCS" == "" ]; then
|
||||
FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
|
||||
@@ -165,9 +159,7 @@ function build_and_install_ops() {
|
||||
else
|
||||
FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
|
||||
fi
|
||||
if [ -d "${OPS_TMP_DIR}" ]; then
|
||||
find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
|
||||
fi
|
||||
find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
|
||||
else
|
||||
echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false."
|
||||
exit 1
|
||||
|
@@ -14,7 +14,7 @@
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
void set_value_by_flags_and_idx(const bool *stop_flags,
|
||||
void set_value_by_flag_and_id(const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
@@ -50,7 +50,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
int length = pre_ids_all_shape[1];
|
||||
int length_input_ids = input_ids.shape()[1];
|
||||
|
||||
set_value_by_flags_and_idx(stop_flags.data<bool>(),
|
||||
set_value_by_flag_and_id(stop_flags.data<bool>(),
|
||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||
input_ids.data<int64_t>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
|
@@ -46,7 +46,7 @@ void update_inputs_kernel(bool *not_need_stop,
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
|
||||
void UpdateInputs(const paddle::Tensor &stop_flags,
|
||||
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu)
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"input_ids", "input_ids_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputs));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputes));
|
||||
|
@@ -317,6 +317,7 @@ void AppendAttentionKernel(
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
@@ -343,6 +344,7 @@ void AppendAttentionKernel(
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
|
@@ -52,7 +52,6 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -75,11 +74,6 @@ __global__ void multi_query_append_attention_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -148,7 +142,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -428,7 +422,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -452,11 +445,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -523,7 +511,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -914,7 +902,6 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -973,7 +960,6 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1148,7 +1134,6 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1221,7 +1206,6 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
|
@@ -57,7 +57,6 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -86,11 +85,6 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -179,7 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -526,7 +520,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -556,11 +549,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -647,7 +635,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -1119,7 +1107,6 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1184,7 +1171,6 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1379,7 +1365,6 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1460,7 +1445,6 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
|
@@ -58,7 +58,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -88,11 +87,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -189,7 +183,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -210,12 +204,16 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale = nullptr;
|
||||
T* v_smem_scale = nullptr;
|
||||
T* k_smem_scale_ptr = nullptr;
|
||||
T* v_smem_scale_ptr = nullptr;
|
||||
smem_t k_scale_smem;
|
||||
smem_t v_scale_smem;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
k_smem_scale_ptr = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale = k_smem_scale + num_frags_z * 16;
|
||||
v_smem_scale_ptr = k_smem_scale_ptr + num_frags_z * 16;
|
||||
k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
|
||||
v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
|
||||
}
|
||||
|
||||
|
||||
@@ -277,6 +275,20 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
k_scale_smem,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -294,24 +306,32 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_v_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
v_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale,
|
||||
cache_k_scale_reg,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale_ptr,
|
||||
cache_k_scale_reg
|
||||
);
|
||||
}
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
&qo_smem,
|
||||
@@ -364,21 +384,29 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
commit_group();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale,
|
||||
cache_v_scale_reg,
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
k_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
ori_kv_idx_base,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale_ptr,
|
||||
cache_v_scale_reg
|
||||
);
|
||||
}
|
||||
|
||||
// compute sfm*v
|
||||
compute_sfm_v_c8<num_frags_x,
|
||||
@@ -409,6 +437,20 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_v_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
v_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
|
||||
}
|
||||
@@ -533,7 +575,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -563,11 +604,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -661,7 +697,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -686,12 +722,16 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale = nullptr;
|
||||
T* v_smem_scale = nullptr;
|
||||
T* k_smem_scale_ptr = nullptr;
|
||||
T* v_smem_scale_ptr = nullptr;
|
||||
smem_t k_scale_smem;
|
||||
smem_t v_scale_smem;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
|
||||
k_smem_scale_ptr = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16;
|
||||
k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
|
||||
v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
|
||||
}
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
@@ -755,6 +795,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
k_scale_smem,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -772,23 +826,31 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_v_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
v_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale,
|
||||
cache_k_scale_reg,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale_ptr,
|
||||
cache_k_scale_reg
|
||||
);
|
||||
}
|
||||
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
@@ -842,21 +904,29 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
commit_group();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale,
|
||||
cache_v_scale_reg,
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
k_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
ori_kv_idx_base,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale_ptr,
|
||||
cache_v_scale_reg
|
||||
);
|
||||
}
|
||||
|
||||
// compute sfm * v
|
||||
compute_sfm_v_c8_iter_sq_bvec<num_frags_x,
|
||||
@@ -887,6 +957,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_v_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
v_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
}
|
||||
wait_group<0>();
|
||||
@@ -1171,7 +1255,6 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1230,7 +1313,6 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1457,7 +1539,6 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1532,7 +1613,6 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
|
@@ -384,53 +384,40 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
template<SharedMemFillMode fill_mode,
|
||||
uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_k_dynamic_scale(
|
||||
T* k_smem_scale,
|
||||
T* cache_k_reg,
|
||||
__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async(
|
||||
smem_t kv_scale_smem,
|
||||
const int* block_table_now,
|
||||
const T* cache_k_scale,
|
||||
const T* cache_kv_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
if (tid < block_size) {
|
||||
k_smem_scale[tid] = cache_k_scale_now[tid];
|
||||
}
|
||||
__syncthreads();
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
|
||||
if (tid < block_size / 8) {
|
||||
const T* cache_k_scale_now = cache_kv_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size + tid * 8;
|
||||
const int kv_idx_this_thread = kv_idx + tid * 8;
|
||||
kv_scale_smem.load_128b_async<fill_mode>(tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
||||
if (kv_idx_this_thread < chunk_end) {
|
||||
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
|
||||
} else {
|
||||
k_smem_scale[ty * 32 + tx] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
|
||||
if (tid < block_size / 8 * 2) {
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * tid / 8;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const int kv_idx_this_thread = kv_idx + tid * 8;
|
||||
const T* cache_k_scale_now = cache_kv_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size + tid % 8 * 8;
|
||||
kv_scale_smem.load_128b_async<fill_mode>(tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -439,54 +426,59 @@ template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_v_dynamic_scale(
|
||||
__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg(
|
||||
T* k_smem_scale,
|
||||
T* cache_k_reg
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
const uint32_t scale_idx = fz * 16 + row_id;
|
||||
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
|
||||
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg(
|
||||
T* v_smem_scale,
|
||||
T* cache_v_reg,
|
||||
const int* block_table_now,
|
||||
const T* cache_v_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
T* cache_v_reg
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
if (tid < block_size) {
|
||||
v_smem_scale[tid] = cache_v_scale_now[tid];
|
||||
}
|
||||
__syncthreads();
|
||||
const uint32_t row_id = tx % 4 * 2;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
|
||||
const uint32_t scale_idx = fz * 16 + row_id;
|
||||
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
||||
if (kv_idx_this_thread < chunk_end) {
|
||||
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
|
||||
} else {
|
||||
v_smem_scale[ty * 32 + tx] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
const uint32_t row_id = tx % 4 * 2;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
|
||||
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
|
||||
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1044,7 +1036,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
8 * (reg_id / 4) + reg_id % 2;
|
||||
bool out_of_boundary;
|
||||
if (mask_offset) {
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
|
||||
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
|
||||
} else {
|
||||
out_of_boundary =
|
||||
(causal
|
||||
|
@@ -18,53 +18,6 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
|
||||
// Note(ZKK)
|
||||
// This function is very easy!
|
||||
// just make HeadDim data to be new HeadDim data!
|
||||
|
||||
template <typename T, int VecSize=8, int HEAD_DIM=128, int NUM_THREADS=32>
|
||||
__device__ __forceinline__ void apply_rope(
|
||||
const T* input,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
T* output,
|
||||
const int thread_id) {
|
||||
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM; head_bias += NUM_THREADS * VecSize) {
|
||||
Load<T, VecSize>(&input[head_bias], &src_vec);
|
||||
const uint32_t emb_idx = head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &output[head_bias]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -75,7 +28,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -211,7 +164,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -317,7 +270,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -428,142 +381,6 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadKVT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
|
||||
LoadT left_vec, right_vec;
|
||||
LoadBiasT left_bias_vec, right_bias_vec;
|
||||
LoadKVT left_cache_vec, right_cache_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int half_head_size = head_size / 2;
|
||||
const int half_rotary_dim = rotary_dim / 2;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||
const int64_t half_hidden_size = hidden_size / 2;
|
||||
// const int64_t offset = 2 * hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / half_hidden_size;
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
if (hi < num_heads && h_bias >= half_rotary_dim){
|
||||
continue;
|
||||
}
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
uint32_t ori_idx_left =
|
||||
start_token_idx * hidden_size + hi * head_size + h_bias;
|
||||
uint32_t ori_idx_right = ori_idx_left + half_head_size;
|
||||
if (hi < num_heads){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else if (hi < num_heads + kv_num_heads){
|
||||
if (h_bias < half_rotary_dim){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
ori_idx_left = ori_idx_left + half_rotary_dim;
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}
|
||||
}
|
||||
|
||||
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);
|
||||
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
if (h_bias < half_rotary_dim){
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
// rope
|
||||
float input_left = static_cast<float>(left_vec[i]);
|
||||
float input_right = static_cast<float>(right_vec[i]);
|
||||
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
|
||||
} else {
|
||||
// write k/v
|
||||
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
|
||||
uint32_t tgt_idx_left =
|
||||
block_idx * kv_num_heads * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size + block_offset * head_size +
|
||||
h_bias;
|
||||
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
if (h_bias < half_rotary_dim) {
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
tgt_idx_left = tgt_idx_left + half_rotary_dim;
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
|
||||
} else {
|
||||
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -574,6 +391,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -687,6 +505,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -820,6 +639,7 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1107,6 +927,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1144,18 +965,44 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(
|
||||
qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
@@ -1330,6 +1177,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1634,6 +1482,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1935,7 +1784,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2332,7 +2181,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2373,18 +2222,44 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(
|
||||
qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
@@ -2604,7 +2479,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -2935,7 +2810,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -3308,7 +3183,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
|
@@ -21,6 +21,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -58,6 +59,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -82,6 +84,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -94,7 +97,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const cudaStream_t& stream,
|
||||
@@ -118,6 +120,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -134,34 +137,13 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (rotary_dim < dim_head){
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}else{
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -175,7 +157,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -186,6 +167,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -208,6 +190,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -231,6 +214,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -263,6 +247,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -288,6 +273,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -313,6 +299,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -338,6 +325,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -363,6 +351,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -397,6 +386,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -424,6 +414,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -451,6 +442,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -478,6 +470,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -504,6 +497,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -540,20 +534,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const float* cos_emb =
|
||||
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
|
||||
const float* sin_emb;
|
||||
int rotary_dim = dim_head;
|
||||
if (rotary_embs) {
|
||||
sin_emb =
|
||||
use_neox_rotary_style
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < dim_head){
|
||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
|
||||
}
|
||||
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
@@ -564,6 +549,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -598,6 +584,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -628,6 +615,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -642,7 +630,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
@@ -660,6 +647,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -692,6 +680,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -725,6 +714,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -762,6 +752,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -787,6 +778,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -834,6 +826,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -863,6 +856,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -891,6 +885,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -919,6 +914,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
|
@@ -23,6 +23,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
|
@@ -449,8 +449,8 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||
const int half_lastdim = last_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * last_dim;
|
||||
const int all_head_num = elem_cnt / last_dim;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_num; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize;
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
|
||||
int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
@@ -900,74 +900,6 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const float *qkv_out_scales,
|
||||
const T *qkv_biases,
|
||||
T *qkv_out,
|
||||
const int64_t elem_cnt,
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int head_dim,
|
||||
const int rotary_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT left_vec;
|
||||
LoadT right_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int rotary_dim_half = rotary_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
|
||||
for (int64_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens && seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / rotary_dim_half;
|
||||
const int h_bias = bias % rotary_dim_half;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
|
||||
const int base_idx_left =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
|
||||
h_bias;
|
||||
const int base_idx_right = base_idx_left + rotary_dim_half;
|
||||
|
||||
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
const float input_left = static_cast<float>(left_vec[i]);
|
||||
const float input_right = static_cast<float>(right_vec[i]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void cache_kernel(
|
||||
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
|
||||
@@ -1004,8 +936,7 @@ __global__ void cache_kernel(
|
||||
const uint32_t qkv_bias = bias % hidden_size;
|
||||
const uint32_t hi = qkv_bias / head_size;
|
||||
const uint32_t h_bias = qkv_bias % head_size;
|
||||
const int32_t ori_bi = batch_id_per_token[token_idx];
|
||||
if (ori_bi == -1) continue; // skip batch_id_per_token[token_idx]=-1
|
||||
const uint32_t ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
@@ -2229,7 +2160,6 @@ void gqa_rotary_qk_variable(
|
||||
const int seq_len,
|
||||
const int input_output_len,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const cudaStream_t &stream,
|
||||
bool use_neox_style = false,
|
||||
bool rope_3d = false) {
|
||||
@@ -2310,38 +2240,7 @@ void gqa_rotary_qk_variable(
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (rotary_dim < dim_head){
|
||||
PD_CHECK((rotary_dim / 2) % PackSize == 0);
|
||||
elem_nums =
|
||||
qkv_out_scales
|
||||
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
|
||||
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
|
||||
if (use_neox_style) {
|
||||
elem_nums /= 2;
|
||||
}
|
||||
const int pack_num_new = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num_new, &grid_size);
|
||||
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
rotary_emb + input_output_len * rotary_dim / 2,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
rope_3d);
|
||||
}else{
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
@@ -2359,7 +2258,6 @@ void gqa_rotary_qk_variable(
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -55,19 +55,9 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto head_dim = meta_data.head_dims;
|
||||
bool is_scale_channel_wise = false;
|
||||
int rotary_dim = head_dim;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (rotary_embs){
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < head_dim){
|
||||
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
|
||||
@@ -135,7 +125,6 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
|
@@ -11,11 +11,10 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/memory/memcpy.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void
|
||||
@@ -117,93 +116,6 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
|
||||
max_len_tensor.data<int>(), batch_size);
|
||||
}
|
||||
|
||||
template <uint32_t config_size>
|
||||
__global__ void search_chunk_size_for_mla(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
int *__restrict__ num_blocks_x,
|
||||
int *__restrict__ res_chunk_size,
|
||||
const int bsz,
|
||||
const int set_chunk_size,
|
||||
const int block_size,
|
||||
const int sm_cout) {
|
||||
const uint32_t conf_id = threadIdx.x;
|
||||
int gridx = 0;
|
||||
if (set_chunk_size > 0 && conf_id == 0) {
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
if (seq_len == 0 || seq_len_encoder > 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
|
||||
gridx += loop_times;
|
||||
}
|
||||
*num_blocks_x = gridx;
|
||||
*res_chunk_size = set_chunk_size;
|
||||
} else if (conf_id < config_size) {
|
||||
__shared__ int gridx_shared[config_size];
|
||||
// chunk_size is a multiple of 64
|
||||
const int chunk_size = block_size << conf_id;
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
if (seq_len == 0 || seq_len_encoder > 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
|
||||
gridx += loop_times;
|
||||
}
|
||||
gridx_shared[conf_id] = gridx;
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
uint32_t res_id = 0;
|
||||
uint32_t max_last_wave_block = 0;
|
||||
for (uint32_t i = 1; i < config_size; i++) {
|
||||
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
|
||||
if (last_wave_block >= max_last_wave_block) {
|
||||
res_id = i;
|
||||
max_last_wave_block = last_wave_block;
|
||||
}
|
||||
}
|
||||
*num_blocks_x = gridx_shared[res_id];
|
||||
*res_chunk_size = block_size << res_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
int *__restrict__ batch_ids,
|
||||
int *__restrict__ tile_ids_per_batch,
|
||||
const int bsz,
|
||||
const int chunk_size) {
|
||||
if (threadIdx.x == 0) {
|
||||
int index = 0;
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
|
||||
if (seq_len == 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
|
||||
if (seq_len_encoder > 0) {
|
||||
loop_times = 0;
|
||||
}
|
||||
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
|
||||
batch_ids[index] = bid;
|
||||
tile_ids_per_batch[index++] = tile_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
int *__restrict__ batch_ids,
|
||||
@@ -279,23 +191,14 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
|
||||
}
|
||||
}
|
||||
|
||||
void GetBlockShapeAndSplitKVBlock(
|
||||
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &decoder_num_blocks_device, // Inplace
|
||||
paddle::Tensor &decoder_chunk_size_device, // Inplace
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
|
||||
paddle::Tensor &kv_batch_ids, // Inplace
|
||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
|
||||
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
@@ -320,120 +223,31 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
int max_system_len = max_len_cpu_ptr[6];
|
||||
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
|
||||
|
||||
paddle::Tensor encoder_batch_ids;
|
||||
paddle::Tensor encoder_tile_ids_per_batch;
|
||||
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor kv_batch_ids;
|
||||
paddle::Tensor kv_tile_ids_per_batch;
|
||||
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||
|
||||
auto max_len_kv =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
|
||||
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
|
||||
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(), bsz);
|
||||
|
||||
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
|
||||
|
||||
// decoder
|
||||
if (max_dec_len_this_time > 0) {
|
||||
const bool mla_use_tensorcore = GetMlaUseTensorcore();
|
||||
if (mla_use_tensorcore && group_size <= 64) {
|
||||
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int sm_cout;
|
||||
cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
|
||||
constexpr int config_size =
|
||||
12; // search space for chunk size:[64, 128, 256, ... 131072]
|
||||
|
||||
search_chunk_size_for_mla<config_size>
|
||||
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
decoder_num_blocks_device.data<int>(),
|
||||
decoder_chunk_size_device.data<int>(),
|
||||
bsz,
|
||||
set_chunk_size,
|
||||
block_size,
|
||||
sm_cout);
|
||||
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
auto decoder_chunk_size_cpu =
|
||||
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
|
||||
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
|
||||
|
||||
// NOTE: (changwenbin) When using auto_chunk,
|
||||
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
|
||||
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
|
||||
|
||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
||||
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape =
|
||||
bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_batch_ids.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
stream));
|
||||
|
||||
|
||||
split_block_for_mla<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
bsz,
|
||||
chunk_size);
|
||||
|
||||
} else {
|
||||
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here
|
||||
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
split_q_block<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_device.data<int>(),
|
||||
bsz,
|
||||
decoder_block_shape_q,
|
||||
group_size);
|
||||
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
}
|
||||
} else {
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
}
|
||||
|
||||
// encoder
|
||||
if (max_enc_len_this_time > 0) {
|
||||
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
|
||||
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
|
||||
const uint32_t max_tile_size_per_bs_kv =
|
||||
div_up(max_enc_dec_len_this_time, block_size);
|
||||
kv_batch_ids =
|
||||
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||
seq_lens_encoder.place());
|
||||
auto kv_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
|
||||
@@ -444,12 +258,16 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
|
||||
block_size, block_size);
|
||||
|
||||
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
|
||||
// Clear buffer
|
||||
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
|
||||
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
const uint32_t encoder_max_tile_size_per_bs_q =
|
||||
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||
encoder_batch_ids =
|
||||
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
auto encoder_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
|
||||
@@ -457,9 +275,54 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
encoder_tile_ids_per_batch.data<int>(),
|
||||
encoder_num_blocks_x.data<int>(), bsz,
|
||||
encoder_block_shape_q, group_size);
|
||||
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
|
||||
encoder_num_blocks_x_cpu =
|
||||
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||
} else {
|
||||
encoder_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
encoder_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
kv_batch_ids =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
}
|
||||
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
// Clear buffer
|
||||
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
auto decoder_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
split_q_block<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_x.data<int>(),
|
||||
bsz,
|
||||
decoder_block_shape_q,
|
||||
group_size);
|
||||
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
|
||||
}
|
||||
|
||||
return {
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks_x_cpu, /*cpu*/
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks_x_cpu, /*cpu*/
|
||||
max_len_kv_cpu, /*cpu*/
|
||||
};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
@@ -469,20 +332,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
"seq_lens_this_time",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks_cpu",
|
||||
"decoder_num_blocks_device",
|
||||
"decoder_chunk_size_device",
|
||||
"max_len_tensor_cpu",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks_x_cpu",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks_x_cpu",
|
||||
"max_len_kv_cpu"
|
||||
"decoder_num_blocks_x_cpu",
|
||||
"max_len_tensor_cpu"
|
||||
})
|
||||
.Outputs({
|
||||
|
||||
paddle::Optional("encoder_batch_ids"),
|
||||
paddle::Optional("encoder_tile_ids_per_batch"),
|
||||
paddle::Optional("encoder_num_blocks_x_cpu"),
|
||||
paddle::Optional("kv_batch_ids"),
|
||||
paddle::Optional("kv_tile_ids_per_batch"),
|
||||
paddle::Optional("kv_num_blocks_x_cpu"),
|
||||
"max_len_kv_cpu"
|
||||
})
|
||||
.Attrs({
|
||||
"encoder_block_shape_q: int",
|
||||
|
@@ -217,7 +217,7 @@ __global__ void append_cache_kv_c16(
|
||||
|
||||
// load k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -235,7 +235,7 @@ __global__ void append_cache_kv_c16(
|
||||
// deal k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||
// layout
|
||||
@@ -278,7 +278,7 @@ __global__ void append_cache_kv_c16(
|
||||
|
||||
// load v_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -296,7 +296,7 @@ __global__ void append_cache_kv_c16(
|
||||
// deal v_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
|
||||
// layout
|
||||
@@ -400,7 +400,7 @@ __global__ void append_cache_kv_c8(
|
||||
|
||||
// load v_smem 64 rows, 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -418,7 +418,7 @@ __global__ void append_cache_kv_c8(
|
||||
// deal k_smem 64 rows, 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter
|
||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter
|
||||
uint32_t col_idx = fy * 32 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||
// layout
|
||||
@@ -466,7 +466,7 @@ __global__ void append_cache_kv_c8(
|
||||
tid % 4 * num_elems_per_128b<CacheT>();
|
||||
// load v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 noce, need 1 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -485,7 +485,7 @@ __global__ void append_cache_kv_c8(
|
||||
// deal v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter
|
||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 noce, need 2 iter
|
||||
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||
// layout
|
||||
@@ -614,7 +614,7 @@ __global__ void append_cache_kv_c4(
|
||||
|
||||
// load k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 noce, need 1 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -632,7 +632,7 @@ __global__ void append_cache_kv_c4(
|
||||
// deal k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 noce, need 2 iter
|
||||
uint32_t col_idx = fy * 64 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||
|
||||
@@ -685,7 +685,7 @@ __global__ void append_cache_kv_c4(
|
||||
tid % 2 * num_elems_per_128b<CacheT>();
|
||||
// load v_smem 128 rows 64 rows
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -704,7 +704,7 @@ __global__ void append_cache_kv_c4(
|
||||
// deal v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||
// layout
|
||||
|
@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &decoder_batch_ids,
|
||||
const paddle::Tensor &decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &decoder_num_blocks_cpu,
|
||||
const paddle::Tensor &decoder_num_blocks,
|
||||
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
|
||||
const paddle::optional<paddle::Tensor> &rotary_embs,
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
@@ -105,7 +105,7 @@ void AppendAttentionWithOutput(
|
||||
const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &decoder_batch_ids,
|
||||
const paddle::Tensor &decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &decoder_num_blocks_cpu,
|
||||
const paddle::Tensor &decoder_num_blocks,
|
||||
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
|
||||
paddle::Tensor &fmha_out,
|
||||
const paddle::optional<paddle::Tensor> &rotary_embs,
|
||||
@@ -255,8 +255,7 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size);
|
||||
const int estimate_total_token_nums);
|
||||
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
@@ -299,23 +298,14 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
|
||||
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
|
||||
const int layer_id);
|
||||
|
||||
void GetBlockShapeAndSplitKVBlock(
|
||||
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &decoder_num_blocks_device, // Inplace
|
||||
paddle::Tensor &decoder_chunk_size_device, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &kv_batch_ids, // Inplace
|
||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
@@ -416,8 +406,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
|
||||
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
|
||||
|
||||
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
|
||||
paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input);
|
||||
const paddle::Tensor &text_input,
|
||||
const paddle::Tensor &image_input);
|
||||
|
||||
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input,
|
||||
@@ -475,18 +465,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks_device,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -571,7 +566,6 @@ std::vector<paddle::Tensor> NoauxTc(
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor);
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
@@ -623,8 +617,6 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
|
||||
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
void clear_ipc_handles(int64_t _fa);
|
||||
|
||||
// speculative decoding Kernel
|
||||
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
const paddle::Tensor& input_ids,
|
||||
@@ -1008,7 +1000,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"),
|
||||
py::arg("block_size"),
|
||||
"per token per block quant and padding transpose scale");
|
||||
"per token per block quant and padding tranpose scale");
|
||||
|
||||
m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"),
|
||||
py::arg("recv_expert_count"), py::arg("block_size"),
|
||||
@@ -1051,7 +1043,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_expert_ffn_wint2.cu
|
||||
* moe/fused_moe/moe_ffn_wint2.cu
|
||||
* moe_expert_ffn_wint2
|
||||
*/
|
||||
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
|
||||
@@ -1231,8 +1223,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
|
||||
|
||||
m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");
|
||||
|
||||
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
|
||||
|
@@ -122,14 +122,10 @@ void register_graph_buffers(fptr_t _fa,
|
||||
for (int i = 0; i < handles.size(); i++) {
|
||||
bytes.emplace_back(handles[i].begin(), handles[i].end());
|
||||
}
|
||||
bytes.reserve(handles.size());
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
void clear_ipc_handles(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
fa->clear_ipc_handles();
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
|
||||
|
@@ -303,7 +303,7 @@ class CustomAllreduce {
|
||||
bool full_nvlink_;
|
||||
|
||||
RankSignals sg_;
|
||||
// Stores an map from a pointer to its peer pointers from all ranks.
|
||||
// Stores an map from a pointer to its peer pointters from all ranks.
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
@@ -517,15 +517,10 @@ class CustomAllreduce {
|
||||
#undef KL
|
||||
}
|
||||
|
||||
void clear_ipc_handles(){
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
ipc_handles_.clear();
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
clear_ipc_handles();
|
||||
}
|
||||
};
|
||||
} // namespace paddle
|
||||
|
@@ -89,11 +89,11 @@ public:
|
||||
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
/// Number of warp-level GEMM oeprations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of warp-level GEMM operations per load for B
|
||||
/// Number of warp-level GEMM oeprations per load for B
|
||||
static constexpr int kWarpGemmIterationsPerLoadForB =
|
||||
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
|
||||
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
|
||||
|
@@ -117,7 +117,7 @@ class LeftGELUAndMul {
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &lhs,
|
||||
FragmentAccumulator const &rhs) const {
|
||||
// Convert source to internal compute numeric type
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
||||
accumulator_to_compute;
|
||||
|
||||
|
@@ -117,7 +117,7 @@ class LeftSiLUAndMul {
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &lhs,
|
||||
FragmentAccumulator const &rhs) const {
|
||||
// Convert source to internal compute numeric type
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
||||
accumulator_to_compute;
|
||||
|
||||
|
@@ -92,7 +92,7 @@ class DualMmaBase {
|
||||
Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
/// Number of warp-level GEMM oeprations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator0::Policy::MmaShape::kK);
|
||||
|
||||
|
@@ -219,7 +219,7 @@ class EpilogueVisitorPerRowPerColNf4 {
|
||||
iterator_C_.clear_mask();
|
||||
}
|
||||
// NOTE(wangbojun) Currently, this kernel don't hanve implantention for
|
||||
// adding elementwise beta, we keep this here for future usage beta_ =
|
||||
// adding elementwise beta, we keep this here for future useage beta_ =
|
||||
// (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr :
|
||||
// params.elementwise.beta); if (beta_ == ElementAccumulator()) {
|
||||
// iterator_C_.clear_mask();
|
||||
|
@@ -176,7 +176,7 @@ struct Nf4DefaultIteratorsTensorOp<cutlass::bfloat16_t,
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
|
||||
public:
|
||||
|
@@ -64,7 +64,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
/// Operation perfomed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
|
@@ -133,7 +133,7 @@ public:
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
/// Number of warp-level GEMM oeprations
|
||||
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
|
||||
static constexpr int kNumKIterationsPerWarpBLoad =
|
||||
|
@@ -509,7 +509,7 @@ public:
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
// TODO(wangbojun) lds_converter can be remove for int8 B input
|
||||
// TOOD(wangbojun) lds_converter can be remove for int8 B input
|
||||
typename TransformBAfterLDS::result_type converted_frag_B =
|
||||
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
|
||||
|
@@ -96,7 +96,7 @@ public:
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
/// Number of warp-level GEMM oeprations
|
||||
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
|
||||
static constexpr int kNumKIterationsPerWarpBLoad =
|
||||
|
@@ -646,7 +646,7 @@ public:
|
||||
// );
|
||||
// }
|
||||
}
|
||||
// TODO(wangbojun) lds_converter can be remove for int8 B input
|
||||
// TOOD(wangbojun) lds_converter can be remove for int8 B input
|
||||
// int4
|
||||
// typename TransformBAfterLDS::result_type converted_frag_B =
|
||||
// lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
|
@@ -59,15 +59,6 @@ inline uint32_t get_cascade_attention_num_threads() {
|
||||
inline bool get_mla_use_tensorcore() {
|
||||
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
|
||||
static const uint32_t mla_use_tensorcore =
|
||||
mla_use_tensorcore_env == nullptr ? 0 : std::stoul(std::string(mla_use_tensorcore_env));
|
||||
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
|
||||
return mla_use_tensorcore != 0 ? true : false;
|
||||
}
|
||||
inline int get_mla_dec_chunk_size(int bsz) {
|
||||
static const char* mla_dec_chunk_size_env =
|
||||
std::getenv("FLAGS_mla_dec_chunk_size");
|
||||
static const int mla_dec_chunk_size =
|
||||
mla_dec_chunk_size_env == nullptr
|
||||
? -1
|
||||
: std::stoi(std::string(mla_dec_chunk_size_env));
|
||||
return bsz > 1 ? mla_dec_chunk_size : 64;
|
||||
}
|
||||
|
@@ -132,7 +132,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_padding_offset)
|
||||
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
|
||||
.Outputs({"x_remove_padding",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
|
@@ -14,8 +14,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
#include "glog/logging.h"
|
||||
#endif
|
||||
@@ -153,34 +151,6 @@ inline int GetGPUComputeCapability(int id) {
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef FP8_E4M3_MAX
|
||||
#define FP8_E4M3_MAX 448.0
|
||||
#endif
|
||||
|
||||
#ifndef DISPATCH_FLOAT_FP6_DTYPE
|
||||
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
|
||||
switch (pd_dtype) { \
|
||||
case phi::DataType::FLOAT32: { \
|
||||
using c_type = float; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::BFLOAT16: { \
|
||||
using c_type = phi::dtype::bfloat16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::FLOAT16: { \
|
||||
using c_type = phi::dtype::float16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1)
|
||||
return num;
|
||||
@@ -223,13 +193,11 @@ public:
|
||||
typedef uint8_t data_t;
|
||||
};
|
||||
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
|
||||
public:
|
||||
typedef __nv_fp8_e4m3 DataType;
|
||||
typedef paddle::float8_e4m3fn data_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
|
||||
T val[Size];
|
||||
@@ -595,36 +563,3 @@ inline int GetSMVersion() {
|
||||
return sm_version;
|
||||
|
||||
}
|
||||
|
||||
inline bool GetMlaUseTensorcore() {
|
||||
static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore();
|
||||
static const bool enable_mla_tensorcore = GetSMVersion() >= 90 ? true : false;
|
||||
const bool mla_use_tensorcore =
|
||||
flags_mla_use_tensorcore && enable_mla_tensorcore;
|
||||
return mla_use_tensorcore;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warpReduceMax(float value) {
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float blockReduceMax(float value) {
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
value = warpReduceMax(value);
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = value;
|
||||
__syncthreads();
|
||||
|
||||
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
if (warpId == 0) value = warpReduceMax(value);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
@@ -171,7 +171,7 @@ struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
|
||||
public:
|
||||
|
@@ -18,6 +18,7 @@
|
||||
#include "iomanip"
|
||||
#include <nvml.h>
|
||||
#include <iostream>
|
||||
#include <nvml.h>
|
||||
// #define PRINT_GPU_MEMORY
|
||||
// 函数用于获取 NVIDIA GPU 显存信息
|
||||
bool getNvidiaGPUMemoryUsage(int callLine) {
|
||||
|
@@ -30,12 +30,10 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
|
||||
std::optional<paddle::Tensor> const& maybe_token_scales,
|
||||
std::string maybe_schedule) {
|
||||
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
|
||||
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
|
||||
std::optional<int64_t> maybe_group_size_opt;
|
||||
std::optional<std::string> maybe_schedule_opt;
|
||||
if (maybe_schedule == "") {
|
||||
maybe_schedule_opt = std::nullopt;
|
||||
} else {
|
||||
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
|
||||
}
|
||||
return machete::mm_dispatch({.A = A,
|
||||
.B = B,
|
||||
@@ -65,8 +63,6 @@ std::vector<paddle::Tensor> MacheteMMKernel(
|
||||
paddle::DataType maybe_out_type;
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
@@ -51,8 +51,6 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
|
||||
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
@@ -70,6 +70,7 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
@@ -77,8 +78,9 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
@@ -95,12 +97,14 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const auto q_head_num = meta_data.q_num_heads;
|
||||
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
const auto max_block_num = bsz * max_block_num_per_seq;
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
|
||||
|
||||
int q_head_dim = meta_data.head_dims;
|
||||
int k_head_dim = meta_data.head_dims;
|
||||
int v_head_dim = meta_data.head_dims_v;
|
||||
// int num_chunks = max_dec_len / chunk_size;
|
||||
int num_chunks = div_up(max_seq_len, 64);
|
||||
int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
|
||||
@@ -123,14 +127,14 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
params.d = reinterpret_cast<float*>(d_tmp->ptr());
|
||||
params.block_tables = const_cast<int*>(block_tables.data<int>());
|
||||
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
|
||||
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
|
||||
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
|
||||
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
|
||||
params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
|
||||
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
|
||||
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
|
||||
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
|
||||
params.chunk_size_device =
|
||||
const_cast<int*>(decoder_chunk_size_device.data<int>());
|
||||
params.num_blocks_x_int = num_blocks_x;
|
||||
params.q_stride_bsz = q_head_num * q_head_dim;
|
||||
params.q_stride_head_num = q_head_dim;
|
||||
params.kv_stride_block_num = block_size * k_head_dim;
|
||||
@@ -147,6 +151,7 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
params.block_size = block_size;
|
||||
params.max_draft_token_num = draft_token_num;
|
||||
params.sm_scale = softmax_scale;
|
||||
params.chunk_size = chunk_size;
|
||||
params.chunk_num = num_chunks;
|
||||
|
||||
if (q_head_dim == 576) {
|
||||
@@ -171,6 +176,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
@@ -178,8 +184,9 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
@@ -203,6 +210,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
@@ -210,8 +218,9 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
|
@@ -47,6 +47,7 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
@@ -54,8 +55,9 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
|
@@ -128,13 +128,12 @@ struct CollectiveMainloop {
|
||||
DTypeMD const* d_ptr;
|
||||
IdType const* kv_block_tables;
|
||||
IdType const* seq_lens_this_time;
|
||||
// IdType const* seq_lens_encoder;
|
||||
IdType const* seq_lens_encoder;
|
||||
IdType const* seq_lens_decoder;
|
||||
IdType const* cumsum_q_seqlens;
|
||||
IdType const* batch_ids;
|
||||
IdType const* tile_ids_per_batch;
|
||||
IdType const* num_blocks_x;
|
||||
IdType const* chunk_size_device;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
@@ -145,7 +144,7 @@ struct CollectiveMainloop {
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
// int chunk_size;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
};
|
||||
@@ -161,13 +160,12 @@ struct CollectiveMainloop {
|
||||
DTypeMD* d_ptr;
|
||||
IdType* kv_block_tables;
|
||||
IdType* seq_lens_this_time;
|
||||
// IdType* seq_lens_encoder;
|
||||
IdType* seq_lens_encoder;
|
||||
IdType* seq_lens_decoder;
|
||||
IdType* cumsum_q_seqlens;
|
||||
IdType* batch_ids;
|
||||
IdType* tile_ids_per_batch;
|
||||
IdType* num_blocks_x;
|
||||
IdType* chunk_size_device;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
@@ -178,7 +176,7 @@ struct CollectiveMainloop {
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
// int chunk_size;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
TMA_KV tma_load_KV;
|
||||
@@ -200,13 +198,12 @@ struct CollectiveMainloop {
|
||||
const_cast<DTypeMD*>(args.d_ptr),
|
||||
const_cast<IdType*>(args.kv_block_tables),
|
||||
const_cast<IdType*>(args.seq_lens_this_time),
|
||||
// const_cast<IdType*>(args.seq_lens_encoder),
|
||||
const_cast<IdType*>(args.seq_lens_encoder),
|
||||
const_cast<IdType*>(args.seq_lens_decoder),
|
||||
const_cast<IdType*>(args.cumsum_q_seqlens),
|
||||
const_cast<IdType*>(args.batch_ids),
|
||||
const_cast<IdType*>(args.tile_ids_per_batch),
|
||||
const_cast<IdType*>(args.num_blocks_x),
|
||||
const_cast<IdType*>(args.chunk_size_device),
|
||||
args.sm_scale,
|
||||
args.bsz,
|
||||
args.max_block_num,
|
||||
@@ -217,7 +214,7 @@ struct CollectiveMainloop {
|
||||
args.kv_stride_block_size,
|
||||
args.o_stride_bsz,
|
||||
args.o_stride_head_num,
|
||||
// args.chunk_size,
|
||||
args.chunk_size,
|
||||
args.chunk_num,
|
||||
args.max_draft_token_num,
|
||||
tma_load_KV
|
||||
@@ -284,9 +281,9 @@ struct CollectiveMainloop {
|
||||
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
@@ -325,9 +322,9 @@ struct CollectiveMainloop {
|
||||
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
|
@@ -57,7 +57,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
@@ -84,9 +84,9 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
@@ -263,7 +263,7 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
@@ -295,9 +295,9 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
|
||||
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
|
@@ -62,12 +62,13 @@ struct Params {
|
||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
|
||||
alignas(16) IdType *block_tables;
|
||||
alignas(16) IdType *seq_lens_this_time;
|
||||
alignas(16) IdType *seq_lens_encoder;
|
||||
alignas(16) IdType *seq_lens_decoder;
|
||||
alignas(16) IdType *cumsum_q_seqlens;
|
||||
alignas(16) IdType *batch_id_per_token;
|
||||
@@ -75,7 +76,7 @@ struct Params {
|
||||
alignas(16) IdType *batch_ids;
|
||||
alignas(16) IdType *tile_ids_per_batch;
|
||||
alignas(16) IdType *num_blocks_x;
|
||||
alignas(16) IdType *chunk_size_device;
|
||||
|
||||
|
||||
uint32_t q_stride_bsz;
|
||||
uint32_t q_stride_head_num;
|
||||
@@ -95,7 +96,9 @@ struct Params {
|
||||
int vo_head_dim;
|
||||
int block_size;
|
||||
int max_draft_token_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int num_blocks_x_int;
|
||||
|
||||
float sm_scale;
|
||||
};
|
||||
@@ -115,7 +118,7 @@ struct Params {
|
||||
return cudaErrorNotSupported; \
|
||||
}
|
||||
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
@@ -134,7 +137,6 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
const int num_blocks_x = mainloop_params.num_blocks_x[0];
|
||||
const int chunk_size = mainloop_params.chunk_size_device[0];
|
||||
|
||||
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
|
||||
|
||||
@@ -203,10 +205,58 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
@@ -259,12 +309,76 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
|
||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
@@ -315,15 +429,15 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
chunk_size,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaStream_t stream) {
|
||||
using DTypeQ = typename KernelTraits::DTypeQ;
|
||||
@@ -346,12 +460,12 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
params.d,
|
||||
params.block_tables,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_encoder,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_ids,
|
||||
params.tile_ids_per_batch,
|
||||
params.num_blocks_x,
|
||||
params.chunk_size_device,
|
||||
params.sm_scale,
|
||||
params.bsz,
|
||||
params.max_block_num,
|
||||
@@ -362,6 +476,7 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
params.kv_stride_block_size,
|
||||
params.o_stride_bsz,
|
||||
params.o_stride_head_num,
|
||||
params.chunk_size,
|
||||
params.chunk_num,
|
||||
params.max_draft_token_num
|
||||
});
|
||||
@@ -385,9 +500,13 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
||||
|
||||
// NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
|
||||
// by the graph.
|
||||
dim3 grid_dims = {multiprocessor_count, 1, 1};
|
||||
int gridx;
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
gridx = multiprocessor_count;
|
||||
} else {
|
||||
gridx = params.num_blocks_x_int;
|
||||
}
|
||||
dim3 grid_dims = {gridx, 1, 1};
|
||||
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
||||
dim3 block_dims(ctaSize, 1, 1);
|
||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
|
||||
@@ -398,38 +517,37 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
constexpr int merge_block_size = 256;
|
||||
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
||||
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
||||
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
|
||||
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
blocky,
|
||||
KernelTraits::HEAD_DIM_VO>
|
||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE *>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_id_per_token,
|
||||
params.chunk_size_device,
|
||||
reinterpret_cast<NV_TYPE *>(params.O),
|
||||
params.q_num_head,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num);
|
||||
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.seq_lens_encoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_id_per_token,
|
||||
reinterpret_cast<NV_TYPE*>(params.O),
|
||||
params.chunk_num,
|
||||
params.q_num_head,
|
||||
params.chunk_size,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num
|
||||
);
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
|
||||
constexpr bool CAUSAL = true;
|
||||
if constexpr (HEAD_DIM_QK == 576) {
|
||||
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
|
||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
|
||||
HEAD_DIM_QK,
|
||||
HEAD_DIM_VO,
|
||||
GROUP_SIZE,
|
||||
|
@@ -249,16 +249,18 @@ struct prefill_softmax_state_t {
|
||||
};
|
||||
|
||||
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
|
||||
const float * __restrict__ multi_m, // [max_num_chunks, bsz, max_draft_token, num_heads]
|
||||
const float * __restrict__ multi_d, // [max_num_chunks, bsz, max_draft_token, num_heads]
|
||||
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
|
||||
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const int * __restrict__ seq_lens_this_time,
|
||||
const int * __restrict__ seq_lens_decoder,
|
||||
const int * __restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int * __restrict__ batch_id_per_token,
|
||||
const int * __restrict__ chunk_size_device,
|
||||
T * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const int num_chunks,
|
||||
const int num_heads,
|
||||
const int chunk_size,
|
||||
const int head_dim,
|
||||
const int token_num,
|
||||
const int bsz,
|
||||
@@ -269,15 +271,13 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
// NOTE : (changwenbin) Batch_id_per_token is initialized to [:]=-1, Marking meaningless batch IDs.
|
||||
if (bid == -1) continue;
|
||||
const int seq_len_q = seq_lens_this_time[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
||||
int seq_len_kv = seq_lens_decoder[bid];
|
||||
if (seq_len_kv == 0) continue;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]);
|
||||
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
// not need merge
|
||||
continue;
|
||||
|
@@ -383,7 +383,7 @@ __global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attenti
|
||||
|
||||
|
||||
template<typename Kernel_traits, typename ParamType>
|
||||
inline __device__ float calculate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType ¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
|
||||
inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType ¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
|
||||
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
|
||||
constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads;
|
||||
const int32_t bi = blockIdx.z;
|
||||
@@ -524,7 +524,7 @@ __global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_a
|
||||
const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize;
|
||||
const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN;
|
||||
|
||||
float inv_global_exp_sum = calculate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
|
||||
float inv_global_exp_sum = caluate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
|
||||
|
||||
|
||||
using T_vec = Vec<cuteType, kNReducePacksize>;
|
||||
|
@@ -40,7 +40,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
if (seq_len == 0) return;
|
||||
|
||||
const int remain_tokens = seq_len - block_idx;
|
||||
const int ramian_tokens = seq_len - block_idx;
|
||||
|
||||
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
|
||||
const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize];
|
||||
@@ -51,7 +51,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < remain_tokens) {
|
||||
if (i < ramian_tokens) {
|
||||
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(k_input + base_load_idx + i * kv_head_num * kHeadDim);
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,7 @@ __global__ void write_encoder_cachekv_c16(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < remain_tokens) {
|
||||
if (i < ramian_tokens) {
|
||||
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(v_input + base_load_idx + i * kv_head_num * kHeadDim);
|
||||
}
|
||||
}
|
||||
|
@@ -50,14 +50,14 @@ __global__ void get_kv_from_cache_c16_kernel(
|
||||
const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx];
|
||||
|
||||
|
||||
const int remain_tokens = seq_len - base_token_idx;
|
||||
const int ramian_tokens = seq_len - base_token_idx;
|
||||
|
||||
if (bidh < kv_head_num) {
|
||||
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < remain_tokens) {
|
||||
if (i < ramian_tokens) {
|
||||
*reinterpret_cast<float4*>(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_k + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
@@ -67,7 +67,7 @@ __global__ void get_kv_from_cache_c16_kernel(
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < remain_tokens) {
|
||||
if (i < ramian_tokens) {
|
||||
*reinterpret_cast<float4*>(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_v + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
|
@@ -872,14 +872,16 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
OutT* out,
|
||||
cudaStream_t &stream) {
|
||||
bool FLAGS_hardamard_use_diagonal_block_matrix = true;
|
||||
|
||||
static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
|
||||
static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
|
||||
stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
|
||||
constexpr int kThreads = 128;
|
||||
if (FLAGS_hardamard_use_diagonal_block_matrix) {
|
||||
const int VecSize = hadamard_block_size / kThreads;
|
||||
const int VecSize = hardamard_moe_block_size / kThreads; // 128 / 128 = 1
|
||||
const int logN = int(ceil(std::log2(kThreads * VecSize)));
|
||||
constexpr int kNChunks = 1;
|
||||
DISPATCH_SP_VS(VecSize, VEC_SIZE, {
|
||||
@@ -989,7 +991,6 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::float16 *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1008,7 +1009,6 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1027,7 +1027,6 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
phi::dtype::bfloat16 *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
@@ -1046,7 +1045,6 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
int8_t *out,
|
||||
cudaStream_t &stream
|
||||
);
|
||||
|
@@ -32,6 +32,5 @@ void MoeFastHardamardWrapper(const T *x_data,
|
||||
const int64_t dim,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
const int hadamard_block_size,
|
||||
OutT* out,
|
||||
cudaStream_t &stream);
|
||||
|
@@ -236,7 +236,7 @@ public:
|
||||
num_experts, k, stream);
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float, int>(
|
||||
topk_gating_softmax_kernelLauncher<float, int>::run(
|
||||
gating_output, nullptr, expert_scales_float, softmax_out_,
|
||||
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
|
||||
num_experts, k, group_moe, stream);
|
||||
@@ -248,7 +248,7 @@ public:
|
||||
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
|
||||
false, stream);
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
initialize_moe_routing_kernelLauncher<T>::run(
|
||||
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
|
||||
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
|
||||
hidden_size, k, stream);
|
||||
@@ -335,14 +335,14 @@ public:
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
}
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
finalize_moe_routing_kernelLauncher<T>::run(
|
||||
fc2_result, output_, fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
|
||||
routed_scaling_factor, stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
finalize_moe_routing_kernelLauncher<T>::run(
|
||||
// fc2_result,
|
||||
fc1_out, output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
|
@@ -1139,7 +1139,9 @@ void topk_gating_softmax_launcher_helper(const T* input,
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = int>
|
||||
void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
struct topk_gating_softmax_kernelLauncher{
|
||||
|
||||
static void run(const T* input,
|
||||
const T* gating_correction_bias,
|
||||
T* output,
|
||||
T* softmax,
|
||||
@@ -1219,6 +1221,7 @@ void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ========================== Permutation things
|
||||
// =======================================
|
||||
@@ -1313,7 +1316,9 @@ __global__ void initialize_moe_routing_kernel(
|
||||
}
|
||||
|
||||
template <typename T, typename OutT = T>
|
||||
void initialize_moe_routing_kernelLauncher(
|
||||
struct initialize_moe_routing_kernelLauncher{
|
||||
|
||||
static void run(
|
||||
const T* unpermuted_input,
|
||||
OutT* permuted_output,
|
||||
const int* expanded_dest_row_to_expanded_source_row,
|
||||
@@ -1356,6 +1361,7 @@ void initialize_moe_routing_kernelLauncher(
|
||||
num_rows * k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ============================== Infer GEMM sizes
|
||||
// =================================
|
||||
@@ -1466,7 +1472,8 @@ __global__ void finalize_moe_routing_kernel(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void finalize_moe_routing_kernelLauncher(
|
||||
struct finalize_moe_routing_kernelLauncher{
|
||||
static void run(
|
||||
const T* expanded_permuted_rows,
|
||||
T* reduced_unpermuted_output,
|
||||
const T* bias,
|
||||
@@ -1498,4 +1505,5 @@ void finalize_moe_routing_kernelLauncher(
|
||||
routed_scaling_factor,
|
||||
num_rows);
|
||||
}
|
||||
};
|
||||
} // namespace phi
|
||||
|
@@ -36,9 +36,6 @@ void MoeDispatchKernel(
|
||||
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
|
||||
using namespace phi;
|
||||
|
||||
if (num_rows == 0){
|
||||
return;
|
||||
}
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -83,7 +80,7 @@ void MoeDispatchKernel(
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill success ?)
|
||||
// (TODO: check fill sucess ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
@@ -103,7 +100,7 @@ void MoeDispatchKernel(
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher(
|
||||
topk_gating_softmax_kernelLauncher<float, int>::run(
|
||||
gating_output.data<float>(),
|
||||
gating_correction_bias ? gating_correction_bias.get().data<float>()
|
||||
: nullptr,
|
||||
@@ -117,13 +114,13 @@ void MoeDispatchKernel(
|
||||
|
||||
if (w4a8_in_scale) {
|
||||
if (permute_input->dtype() == paddle::DataType::INT8) {
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
|
||||
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
hidden_size, moe_topk, stream);
|
||||
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
|
||||
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
|
||||
permuted_rows_, expert_idx_per_token->data<int32_t>(),
|
||||
w4a8_in_scale->data<float>(),
|
||||
@@ -131,7 +128,7 @@ void MoeDispatchKernel(
|
||||
hidden_size, moe_topk, stream);
|
||||
}
|
||||
} else {
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
initialize_moe_routing_kernelLauncher<data_t>::run(
|
||||
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(), nullptr,
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
@@ -188,15 +185,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto expert_idx_per_token =
|
||||
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
||||
|
||||
if (token_rows == 0){
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
topk_weight,
|
||||
topk_idx,
|
||||
expert_idx_per_token};
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
|
||||
|
@@ -35,8 +35,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const std::string& quant_method,
|
||||
paddle::Tensor ffn_out,
|
||||
bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
const int estimate_total_token_nums) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
@@ -292,7 +291,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
|
||||
stream
|
||||
);
|
||||
@@ -342,7 +340,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size / 2,
|
||||
num_max_tokens_per_expert,
|
||||
used_in_ep_low_latency,
|
||||
hadamard_block_size,
|
||||
act_out_tensor.data<data_t>(),
|
||||
stream
|
||||
);
|
||||
@@ -406,15 +403,13 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums, const int hadamard_block_size) {
|
||||
const int estimate_total_token_nums) {
|
||||
|
||||
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
|
||||
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
||||
permute_input.dtype();
|
||||
auto ffn_out = paddle::empty_like(permute_input, t_type);
|
||||
if(permute_input.numel() == 0){
|
||||
return ffn_out;
|
||||
}
|
||||
|
||||
switch (t_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
|
||||
@@ -429,8 +424,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
quant_method,
|
||||
ffn_out,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size);
|
||||
estimate_total_token_nums);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
@@ -445,8 +439,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
quant_method,
|
||||
ffn_out,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size);
|
||||
estimate_total_token_nums);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeExpertFFN");
|
||||
@@ -465,8 +458,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
const int estimate_total_token_nums) {
|
||||
return {MoeExpertFFNFunc(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
up_gate_proj_weight,
|
||||
@@ -478,8 +470,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
expert_idx_per_token,
|
||||
quant_method,
|
||||
used_in_ep_low_latency,
|
||||
estimate_total_token_nums,
|
||||
hadamard_block_size)};
|
||||
estimate_total_token_nums)};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
@@ -494,8 +485,7 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
|
||||
const std::string& quant_method,
|
||||
const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size) {
|
||||
const int estimate_total_token_nums) {
|
||||
return {permute_input_shape};
|
||||
}
|
||||
|
||||
@@ -509,7 +499,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
|
||||
const std::string &quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums, const int hadamard_block_size) {
|
||||
const int estimate_total_token_nums) {
|
||||
if (quant_method == "w4a8" || quant_method == "w4afp8") {
|
||||
return {up_gate_proj_scale_dtype.get()};
|
||||
} else {
|
||||
@@ -565,8 +555,6 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
* Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
|
||||
* - used_in_ep_low_latency: Whether running in low latency mode
|
||||
* Affects activation function implementation
|
||||
* - estimate_total_token_nums: estimate total token numbers
|
||||
* - hadamard_block_size: hadamard block size for w4a8/w4afp8 quantization
|
||||
*
|
||||
* Note:
|
||||
* - w4a8 mode requires additional workspace memory allocation
|
||||
@@ -583,7 +571,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
|
||||
paddle::Optional("down_proj_in_scale"),
|
||||
paddle::Optional("expert_idx_per_token")})
|
||||
.Outputs({"output_tensor"})
|
||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"})
|
||||
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
|
||||
|
@@ -36,7 +36,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto stream = ffn_out.stream();
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
finalize_moe_routing_kernelLauncher<data_t>::run(
|
||||
ffn_out.data<data_t>(), output->data<data_t>(),
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
|
||||
@@ -59,10 +59,6 @@ paddle::Tensor MoeExpertReduceFunc(
|
||||
|
||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||
|
||||
if(num_rows == 0){
|
||||
return output;
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||
|
@@ -22,18 +22,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -59,12 +64,9 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
// NOTE: (changwenbin) In cuda graph, it will be fixed in the capture stage
|
||||
// int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
|
||||
int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
|
||||
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
|
||||
int max_len_kv_data = max_len_kv.data<int>()[0];
|
||||
// int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
|
||||
//
|
||||
|
||||
const bool mla_use_tensorcore = get_mla_use_tensorcore();
|
||||
auto sm_version = GetSMVersion();
|
||||
@@ -94,6 +96,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
out_linear_smooths,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
batch_id_per_token,
|
||||
block_tables,
|
||||
@@ -101,8 +104,9 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
cache_quant_type_str,
|
||||
decoder_chunk_size_device,
|
||||
decoder_num_blocks_data,
|
||||
max_input_length,
|
||||
max_len_kv_data,
|
||||
softmax_scale,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -141,18 +145,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -199,18 +208,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
cu_seqlens_q,
|
||||
batch_id_per_token,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
decoder_chunk_size_device,
|
||||
decoder_num_blocks_cpu,
|
||||
max_enc_len_this_time,
|
||||
max_dec_len_this_time,
|
||||
max_len_kv,
|
||||
attn_mask,
|
||||
@@ -240,18 +254,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
cu_seqlens_q,
|
||||
batch_id_per_token,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
decoder_chunk_size_device,
|
||||
decoder_num_blocks_cpu,
|
||||
max_enc_len_this_time,
|
||||
max_dec_len_this_time,
|
||||
max_len_kv,
|
||||
attn_mask,
|
||||
@@ -288,18 +307,23 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
|
||||
const std::vector<int64_t>& query_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
const std::vector<int64_t>& value_cache_shape,
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& batch_id_per_token_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& encoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& kv_batch_ids_shape,
|
||||
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& kv_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_chunk_size_device_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
|
||||
const std::vector<int64_t>& max_enc_len_this_time_shape,
|
||||
const std::vector<int64_t>& max_dec_len_this_time_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
@@ -337,18 +361,23 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
|
||||
const paddle::DataType& query_dtype,
|
||||
const paddle::DataType& key_cache_dtype,
|
||||
const paddle::DataType& value_cache_dtype,
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& batch_id_per_token_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& encoder_num_blocks_dtype,
|
||||
const paddle::DataType& kv_batch_ids_dtype,
|
||||
const paddle::DataType& kv_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& kv_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_batch_ids_dtype,
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_chunk_size_device_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_cpu_dtype,
|
||||
const paddle::DataType& max_enc_len_this_time_dtype,
|
||||
const paddle::DataType& max_dec_len_this_time_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
@@ -386,18 +415,23 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention)
|
||||
.Inputs({"query",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"cu_seqlens_q",
|
||||
"batch_id_per_token",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"decoder_chunk_size_device",
|
||||
"decoder_num_blocks_cpu",
|
||||
"max_enc_len_this_time",
|
||||
"max_dec_len_this_time",
|
||||
"max_len_kv",
|
||||
paddle::Optional("attn_mask"),
|
||||
|
@@ -26,7 +26,6 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor) {
|
||||
auto input_shape = scores_with_bias.shape();
|
||||
PD_CHECK(input_shape.size() == 2);
|
||||
@@ -49,7 +48,6 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
|
||||
@@ -78,7 +76,6 @@ PD_BUILD_STATIC_OP(noaux_tc)
|
||||
.Attrs({"n_group: int",
|
||||
"topk_group: int",
|
||||
"topk:int",
|
||||
"renormalize: bool",
|
||||
"routed_scaling_factor: float"})
|
||||
.SetKernelFn(PD_KERNEL(NoauxTc))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user