Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-01 14:52:33 +08:00)

Compare commits: release/2. ... v2.0.0 (1 commit)

Commit: aa9d17b5ae
7  .flake8

@@ -1,7 +0,0 @@
[flake8]
ignore = E203, E402, E501, E731, E741, W503, W605, E722, E231, W604, E702, E226, E221, E713, E271
max-line-length = 119

# E402: module level import not at top of file
per-file-ignores =
    __init__.py:F401,F403,E402
48  .github/workflows/Codestyle-Check.yml (vendored)

@@ -1,48 +0,0 @@
name: Codestyle-Check

on:
  pull_request:
    branches: ["develop"]

jobs:
  pre-commit:
    name: Pre Commit
    if: ${{ github.repository_owner == 'PaddlePaddle' }}
    runs-on: ubuntu-latest
    env:
      PR_ID: ${{ github.event.pull_request.number }}
      BRANCH: develop

    steps:
      - name: Cleanup
        run: |
          rm -rf * .[^.]*

      - name: Checkout base repo
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.base.ref }}
          fetch-depth: 1000

      - name: Merge PR to test branch
        run: |
          git fetch origin pull/${PR_ID}/merge
          git checkout -b test FETCH_HEAD

      - name: Setup python3.10
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          cache: 'pip'

      - name: Install dependencies
        run: |
          pip install pre-commit==4.2.0 cpplint==1.6.0 clang-format==13.0.0

      - name: Check pre-commit
        env:
          SKIP_CLANG_TIDY_CHECK: "ON"
        run: |
          set +e
          bash -x tools/codestyle/pre_commit.sh;EXCODE=$?
          exit $EXCODE
174  .github/workflows/_build_linux.yml (vendored)
@@ -1,174 +0,0 @@
|
||||
name: FastDeploy Linux GPU Build Task
|
||||
description: "FastDeploy packages build and upload"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
COMPILE_ARCH:
|
||||
description: "Build GPU Archs"
|
||||
required: true
|
||||
type: string
|
||||
default: "80,90"
|
||||
WITH_NIGHTLY_BUILD:
|
||||
description: "Enable nightly build mode (e.g. add date suffix to version)"
|
||||
required: false
|
||||
type: string
|
||||
default: "ON"
|
||||
FD_VERSION:
|
||||
description: "FastDeploy Package Version"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
UPLOAD:
|
||||
description: "Upload Package"
|
||||
required: false
|
||||
type: string
|
||||
default: "ON"
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
outputs:
|
||||
wheel_path:
|
||||
description: "Output path of the generated wheel"
|
||||
value: ${{ jobs.fd-build.outputs.wheel_path }}
|
||||
jobs:
|
||||
fd-build:
|
||||
runs-on: [self-hosted, GPU-h1z1-4Cards]
|
||||
outputs:
|
||||
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
IS_PR: ${{ github.event_name == 'pull_request' }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
- name: FastDeploy Build
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
compile_arch: ${{ inputs.COMPILE_ARCH }}
|
||||
fd_version: ${{ inputs.FD_VERSION }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
run: |
|
||||
set -x
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | cut -d'-' -f2)
|
||||
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
|
||||
CACHE_DIR=${CACHE_DIR:-${{ github.workspace }}}
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host \
|
||||
--cap-add=SYS_PTRACE --privileged --shm-size=64G \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
-e "COMPILE_ARCH=${compile_arch}" \
|
||||
-e "FD_VERSION=${fd_version}" \
|
||||
-e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \
|
||||
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
|
||||
if [[ -n "${FD_VERSION}" ]]; then
|
||||
export FASTDEPLOY_VERSION=${FD_VERSION}
|
||||
echo "Custom FastDeploy version: ${FASTDEPLOY_VERSION}"
|
||||
fi
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
|
||||
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
|
||||
DATE_ONLY=$(echo $GIT_COMMIT_TIME | sed "s/ .*//;s/-//g")
|
||||
echo "Git Commit Time: $GIT_COMMIT_TIME"
|
||||
echo "Date Only: $DATE_ONLY"
|
||||
export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}"
|
||||
fi
|
||||
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
|
||||
pip config set install.trusted-host pip.baidu.com
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install wheel
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
# Build RDMA support
|
||||
export ENABLE_FD_RDMA=1
|
||||
bash build.sh 1 python false [${COMPILE_ARCH}]
|
||||
ls ./dist/*.whl
|
||||
'
|
||||
- name: Package Upload
|
||||
id: set_output
|
||||
env:
|
||||
compile_arch: ${{ inputs.COMPILE_ARCH }}
|
||||
run: |
|
||||
set -x
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]];then
|
||||
commit_id=${{ github.event.pull_request.head.sha }}
|
||||
pr_num=${{ github.event.pull_request.number }}
|
||||
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
|
||||
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
commit_id=${{ github.sha }}
|
||||
tag_name=${{ github.ref_name }}
|
||||
target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}/SM${compile_arch//,/_}
|
||||
else
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}/SM${compile_arch//,/_}
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python --version
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
cd FastDeploy/dist/
|
||||
matches=($(ls fastdeploy*.whl))
|
||||
if [ ${#matches[@]} -ne 1 ]; then
|
||||
echo "Error: Found ${#matches[@]} matching files, expected exactly 1"
|
||||
exit 1
|
||||
fi
|
||||
fd_wheel_name=${matches[0]}
|
||||
echo "Found: $fd_wheel_name"
|
||||
tree -L 3
|
||||
python ${push_file} fastdeploy*.whl ${target_path}
|
||||
target_path_stripped="${target_path#paddle-github-action/}"
|
||||
WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "wheel_path=${WHEEL_PATH}" >> $GITHUB_OUTPUT
|
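For reference, a standalone sketch of how the nightly-build step in the workflow above derives the `.devYYYYMMDD` version suffix. This is illustrative only; it assumes a checked-out git repository, and the `2.0.0` fallback value is a made-up example, not part of the workflow.

```bash
# Illustrative sketch of the nightly version suffix logic used in _build_linux.yml.
# Assumes a checked-out git repository; the 2.0.0 fallback is an example value only.
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)    # e.g. "2025-06-30 12:34:56 +0800"
DATE_ONLY=$(echo "$GIT_COMMIT_TIME" | sed "s/ .*//;s/-//g")    # strip the time, drop dashes -> 20250630
export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION:-2.0.0}.dev${DATE_ONLY}"
echo "${FASTDEPLOY_VERSION}"                                   # e.g. 2.0.0.dev20250630
```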
78  .github/workflows/_clone_linux.yml (vendored)
@@ -1,78 +0,0 @@
|
||||
name: FastDeploy Code Clone
|
||||
description: "FastDeploy clone and upload"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
bos_dir:
|
||||
type: string
|
||||
required: false
|
||||
default: 'FastDeploy'
|
||||
outputs:
|
||||
repo_archive_url:
|
||||
description: "Compressed source code archive."
|
||||
value: ${{ jobs.code-clone.outputs.repo_archive_url }}
|
||||
jobs:
|
||||
code-clone:
|
||||
runs-on:
|
||||
group: HK-Clone
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request'
|
||||
&& github.event.pull_request.base.ref
|
||||
|| github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Merge PR (if needed)
|
||||
if: ${{ github.event_name == 'pull_request' }}
|
||||
run: |
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
echo "Fetching and merging PR..."
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge --no-ff pr/${{ github.event.pull_request.number }}
|
||||
echo "PR Branch log "
|
||||
git log --oneline -n 5 pr/${{ github.event.pull_request.number }}
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: paddle
|
||||
SK: paddle
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]];then
|
||||
commit_id=${{ github.event.pull_request.head.sha }}
|
||||
pr_num=${{ github.event.pull_request.number }}
|
||||
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}
|
||||
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
commit_id=${{ github.sha }}
|
||||
tag_name=${{ github.ref_name }}
|
||||
target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}
|
||||
else
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-github-action/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
46  .github/workflows/ci.yml (vendored)
@@ -2,9 +2,7 @@ name: CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
branches: [ develop ]
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
@@ -29,11 +27,9 @@ jobs:
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
@@ -42,7 +38,7 @@ jobs:
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
git clone ${REPO} ${REPO_NAME}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
@@ -60,36 +56,14 @@ jobs:
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [ "${last_char}" = "1" ]; then
|
||||
gpu_id=2
|
||||
DEVICES="2,3"
|
||||
if [[ "$last_char" =~ [0-3] ]]; then
|
||||
gpu_id="$last_char"
|
||||
else
|
||||
gpu_id=0
|
||||
DEVICES="0,1"
|
||||
gpu_id="0"
|
||||
fi
|
||||
|
||||
FLASK_PORT=$((41068 + gpu_id * 100))
|
||||
FD_API_PORT=$((41088 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((41058 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((41078 + gpu_id * 100))
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
FD_API_PORT=$((9180 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((9170 + gpu_id * 100))
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
@@ -102,8 +76,8 @@ jobs:
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c "
|
||||
--gpus device=${gpu_id} ${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci.sh
|
||||
"
|
||||
"
|
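As a quick reference, here is a minimal sketch of the per-runner port assignment used by the updated ci.yml above. The runner name is a hypothetical example; the formulas are the ones from the workflow.

```bash
# Illustrative sketch of ci.yml's port derivation; "GPU-h1z1-2" is a hypothetical runner name.
runner_name="GPU-h1z1-2"
last_char="${runner_name: -1}"
if [[ "$last_char" =~ [0-3] ]]; then
  gpu_id="$last_char"
else
  gpu_id="0"
fi
FD_API_PORT=$((9180 + gpu_id * 100))           # 9380 for gpu_id=2
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))  # 9350
FD_METRICS_PORT=$((9170 + gpu_id * 100))       # 9370
echo "${FD_API_PORT} ${FD_ENGINE_QUEUE_PORT} ${FD_METRICS_PORT}"
```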
84  .github/workflows/ci_iluvatar.yml (vendored)
@@ -1,84 +0,0 @@
|
||||
name: CI_ILUVATAR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ develop ]
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.pull_request.number }}-iluvatar-ci
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
CI_ILUVATAR:
|
||||
runs-on: [self-hosted, IXUCA]
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
# Because the system git version is lower than 2.23, actions/checkout cannot be used.
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge pr/${{ github.event.pull_request.number }}
|
||||
git log -n 3 --oneline
|
||||
else
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [[ "$last_char" =~ [0-3] ]]; then
|
||||
gpu_id="$last_char"
|
||||
else
|
||||
gpu_id="0"
|
||||
fi
|
||||
FD_API_PORT=$((9180 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((9170 + gpu_id * 100))
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host --pid=host --cap-add=ALL --privileged --shm-size=64G \
|
||||
-v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "/data1/fastdeploy:/data1/fastdeploy" \
|
||||
-e "MODEL_PATH=/ssd3/model" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci_iluvatar.sh
|
||||
"
|
87  .github/workflows/ci_xpu.yml (vendored)
@@ -1,87 +0,0 @@
|
||||
name: CI_XPU
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.pull_request.number }}-xpu-ci
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
CI_XPU:
|
||||
runs-on: [self-hosted, XPU-P800-8Card]
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
# Because the system git version is lower than 2.23, actions/checkout cannot be used.
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge pr/${{ github.event.pull_request.number }}
|
||||
git log -n 3 --oneline
|
||||
else
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [[ "$last_char" =~ [0-3] ]]; then
|
||||
gpu_id="$last_char"
|
||||
else
|
||||
gpu_id="0"
|
||||
fi
|
||||
FD_API_PORT=$((9180 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((9170 + gpu_id * 100))
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "/ssd3:/ssd3" \
|
||||
-e "MODEL_PATH=/ssd3/model" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci_xpu.sh
|
||||
"
|
6  .github/workflows/gh-pages.yml (vendored)

@@ -3,6 +3,8 @@ name: Deploy GitHub Pages
on:
  push:
    branches: [ develop ]
  pull_request:
    branches: [ develop ]

permissions:
  contents: write
@@ -19,6 +21,4 @@
      - name: Deploy to GitHub Pages
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git
          mkdocs gh-deploy --force --remote-name origin
        run: mkdocs gh-deploy --force --remote-name origin
35  .github/workflows/pr_build_and_test.yml (vendored)

@@ -1,35 +0,0 @@
name: PR Build and Test
on:
  pull_request:
    types: [opened, synchronize]
    branches: [develop, release/**]
permissions: read-all

concurrency:
  group: ${{ github.event.pull_request.number }}-${{ github.workflow }}
  cancel-in-progress: true

jobs:
  clone:
    name: FD-Clone-Linux
    uses: ./.github/workflows/_clone_linux.yml

  build:
    name: FD-Build-Linux
    needs: clone
    uses: ./.github/workflows/_build_linux.yml
    with:
      DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
      FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
      COMPILE_ARCH: "90"
      WITH_NIGHTLY_BUILD: "OFF"
      FD_VERSION: "0.0.0"

  resultshow:
    name: Use Build Output
    needs: build
    runs-on: ubuntu-latest
    steps:
      - name: Print wheel path
        run: |
          echo "The built wheel is located at: ${{ needs.build.outputs.wheel_path }}"
2  .gitignore (vendored)

@@ -162,5 +162,3 @@ custom_ops/tmp*
build

.ccls-cache

third_party
@@ -3,30 +3,20 @@ default_install_hook_types:
  - commit-msg
default_stages:
  - pre-commit # Run locally
  - commit-msg
  # - manual # Run in CI
repos:
  - repo: https://github.com/psf/black.git
    rev: 25.1.0
    hooks:
      - id: black
        files: \.(py|pyi)$
        additional_dependencies: [toml]
  # Automatic import sorting
  - repo: https://github.com/PyCQA/isort
    rev: 5.11.5
    hooks:
      - id: isort
  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
  # Formatting
  - repo: https://github.com/google/yapf
    rev: v0.43.0
    hooks:
      - id: yapf
        args: [--in-place, --verbose]
  # Linting
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.7
    hooks:
      - id: ruff
        args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
        args: [--output-format, github, --fix, --line-length=120]
  # # Spell checking
  # - repo: https://github.com/codespell-project/codespell
  #   rev: v2.4.1
@@ -34,13 +24,26 @@ repos:
  #     - id: codespell
  #       additional_dependencies: ['tomli']
  #       args: ['--toml', 'pyproject.toml']
  # Automatic import sorting
  - repo: https://github.com/PyCQA/isort
    rev: 6.0.1
    hooks:
      - id: isort
  # # Formatting
  # - repo: https://github.com/pre-commit/mirrors-clang-format
  #   rev: v20.1.3
  #   hooks:
  #     - id: clang-format
  #       # exclude: '.*'
  #       types_or: [c++, cuda]
  #       args: [--style=file, --verbose]

  # markdown
  - repo: https://github.com/jackdewinter/pymarkdown
    rev: v0.9.29
    hooks:
      - id: pymarkdown
        args: ["-d", "MD029,MD031", fix]
        args: [fix]
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
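The configuration above is a standard .pre-commit-config.yaml; running the same hooks locally works with the stock pre-commit CLI. This is generic pre-commit usage, not taken from this repository's documentation.

```bash
# Generic pre-commit usage for the config above (illustrative, not part of the diff).
pip install pre-commit
pre-commit install            # install the git hook so the checks run on every commit
pre-commit run --all-files    # run all configured hooks once across the repository
```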
@@ -8,17 +8,14 @@
    <a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
    <a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
    <a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>

</p>

<p align="center">
    <a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
    <a href="https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/"><b> Installation </b></a>
    |
    <a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
    |
    <a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>

</p>

--------------------------------------------------------------------------------
@@ -26,10 +23,6 @@

## News

**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)

**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)

**[2025-06] 🔥 Released FastDeploy v2.0:** Supports inference and deployment for ERNIE 4.5. Furthermore, we open-source an industrial-grade PD disaggregation with context caching, dynamic role switching for effective resource utilization to further enhance inference performance for MoE models.

## About
@@ -41,10 +41,7 @@ python -m pip install -r requirements.txt
--metric-percentiles 80,95,99,99.9,99.95,99.99: percentiles of the performance metrics reported in the results
--num-prompts 1: total number of requests to send
--max-concurrency 1: benchmark concurrency
--save-result: enable result saving; results are written to a JSON file (default False, not saved)
--debug: enable debug mode, printing the payload and output of every request (default False)
--shuffle: whether to shuffle the dataset (default False, no shuffling)
--seed: random seed used when shuffling the dataset (default 0)
--save-result: enable result saving; results are written to a JSON file
```

##### Debugging a single request against the /v1/chat/completions endpoint
@@ -108,30 +105,3 @@ python benchmark_serving.py \
    --save-result > infer_log.txt 2>&1 &
```

### Speculative decoding benchmark tool

#### Usage:

```bash
python benchmarks/benchmark_mtp.py \
  --host 127.0.0.1 --port 8000 \
  --max-concurrency 16 32 64 96 --num-prompts 256 \
  --acceptance-rate 0.8 --draft-token-steps 1 2 3 \
  --s_itl-base-model 15.88 22.84 16.47 16.93 \
  --dataset-name EBChat \
  --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
```

#### Parameter description

```bash
--host: service IP address, used to build the request URL
--port: service HTTP port, used to build the request URL
--max-concurrency: benchmark concurrency levels
--num-prompts: total number of requests to send
--acceptance-rate: simulated acceptance rate for speculative decoding
--draft-token-steps: number of draft token steps for speculative decoding
--s_itl-base-model: decode latency of the base model, obtainable from the serving benchmark above; values correspond one-to-one with the batch sizes
--dataset-name: dataset class to use; "EBChat" reads datasets re-exported in FD format
--dataset-path: path to the test dataset
```
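As a usage illustration of the parameters documented above, a smaller single-point run might look like the following. All values here are made up; only the flags come from the parameter list above.

```bash
# Hypothetical reduced run of benchmark_mtp.py; the argument values are examples only.
python benchmarks/benchmark_mtp.py \
  --host 127.0.0.1 --port 8000 \
  --max-concurrency 16 --num-prompts 64 \
  --acceptance-rate 0.7 --draft-token-steps 1 \
  --s_itl-base-model 15.88 \
  --dataset-name EBChat \
  --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
```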
@@ -29,14 +29,13 @@ from typing import Optional
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncInput:
|
||||
"""Input for requesting LLMs via API"""
|
||||
|
||||
no: int
|
||||
prompt: str
|
||||
history_QA: Optional[dict]
|
||||
hyper_parameters: dict
|
||||
@@ -50,14 +49,11 @@ class RequestFuncInput:
|
||||
multi_modal_content: Optional[dict] = None
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
debug: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncOutput:
|
||||
"""Output for requesting LLMs via API"""
|
||||
|
||||
no: int = 0
|
||||
generated_text: str = ""
|
||||
reasoning_content: str = ""
|
||||
success: bool = False
|
||||
@@ -68,7 +64,7 @@ class RequestFuncOutput:
|
||||
itl: list = field(default_factory=list) # list of inter-token latencies
|
||||
tpot: float = 0.0 # avg next-token latencies
|
||||
prompt_len: int = 0
|
||||
prompt_tokens: int = 0  # number of input tokens returned by the inference side
prompt_tokens: int = 0  # number of input tokens returned by the inference side
|
||||
error: str = ""
|
||||
|
||||
|
||||
@@ -78,19 +74,22 @@ async def async_request_eb_openai_chat_completions(
|
||||
) -> RequestFuncOutput:
|
||||
"""Request an LLM using EB OpenAI"""
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."
|
||||
assert api_url.endswith(
|
||||
("completions", "profile")
|
||||
), "OpenAI Chat Completions API URL must end with 'completions'."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"model": "default",
|
||||
"messages": request_func_input.history_QA,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
"continuous_usage_stats": True,
|
||||
"continuous_usage_stats": True
|
||||
},
|
||||
}
|
||||
# Hyperparameters are passed in via YAML
|
||||
@@ -98,10 +97,6 @@ async def async_request_eb_openai_chat_completions(
|
||||
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
|
||||
if request_func_input.debug:
|
||||
print(f"payload:{json.dumps(payload, ensure_ascii=False)}")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
@@ -109,20 +104,21 @@ async def async_request_eb_openai_chat_completions(
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = 0
|
||||
output.no = request_func_input.no
|
||||
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
# print("####chunk:", chunk, type(chunk))
|
||||
timestamp = time.perf_counter()
|
||||
@@ -136,20 +132,21 @@ async def async_request_eb_openai_chat_completions(
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
# cached_tokens
|
||||
output.prompt_len = (
|
||||
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
|
||||
)
|
||||
output.prompt_len = data["usage"]["prompt_tokens_details"]["cached_tokens"]
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
output.generated_text += content or ""
|
||||
output.reasoning_content += reason_content or ""
|
||||
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
|
||||
elif usage := data.get("usage", {}):
|
||||
output.output_tokens = usage.get("completion_tokens", 0)
|
||||
output.prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
output.arrival_time.append(choices[0].get("arrival_time"))
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
output.prompt_tokens = usage.get(
|
||||
"prompt_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
@@ -162,12 +159,7 @@ async def async_request_eb_openai_chat_completions(
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
error_text = await response.text()
|
||||
print(
|
||||
"####error response:",
|
||||
error_text,
|
||||
"####payload:",
|
||||
payload,
|
||||
)
|
||||
print("####error response:", error_text, "####payload:", payload)
|
||||
output.error = error_text or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
@@ -181,8 +173,6 @@ async def async_request_eb_openai_chat_completions(
|
||||
f.write(str(output) + "\n")
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
if request_func_input.debug:
|
||||
print("#####final_output:", output)
|
||||
return output
|
||||
|
||||
|
||||
@@ -196,14 +186,15 @@ async def async_request_eb_openai_completions(
|
||||
("completions", "profile")
|
||||
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"model": "default",
|
||||
"prompt": request_func_input.prompt,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
"continuous_usage_stats": True,
|
||||
"continuous_usage_stats": True
|
||||
},
|
||||
}
|
||||
# Hyperparameters are passed in via YAML
|
||||
@@ -211,25 +202,19 @@ async def async_request_eb_openai_completions(
|
||||
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
|
||||
if request_func_input.debug:
|
||||
print("payload:", json.dumps(payload, ensure_ascii=False))
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output.no = request_func_input.no
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
first_chunk_received = False
|
||||
async for chunk_bytes in response.content:
|
||||
@@ -237,10 +222,10 @@ async def async_request_eb_openai_completions(
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
# print("####chunk:", chunk, chunk.usage)
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
@@ -250,40 +235,35 @@ async def async_request_eb_openai_completions(
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
first_chunk_received = True
|
||||
ttft = timestamp - st
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += text or ""
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
|
||||
output.arrival_time.append(choices[0].get("arrival_time"))
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.prompt_tokens = usage.get("prompt_tokens")
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
output.prompt_tokens = usage.get(
|
||||
"prompt_tokens")
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
output.error = (
|
||||
"Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
|
||||
)
|
||||
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!")
|
||||
output.generated_text = generated_text
|
||||
output.latency = most_recent_timestamp - st
|
||||
|
||||
if output.generated_text == "":
|
||||
output.success = False
|
||||
output.error = "No generated text found!"
|
||||
else:
|
||||
output.success = True
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@@ -292,9 +272,6 @@ async def async_request_eb_openai_completions(
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if request_func_input.debug:
|
||||
print(f"final_output:{output}")
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
@@ -308,7 +285,8 @@ async def async_request_tgi(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
params = {
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"do_sample": True,
|
||||
@@ -355,7 +333,8 @@ async def async_request_tgi(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
output.arrival_time.append(data["arrival_time"])
|
||||
@@ -384,7 +363,8 @@ async def async_request_trt_llm(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
"text_input": request_func_input.prompt,
|
||||
@@ -409,7 +389,8 @@ async def async_request_trt_llm(
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
output.generated_text += data["text_output"]
|
||||
@@ -421,7 +402,8 @@ async def async_request_trt_llm(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
@@ -446,7 +428,8 @@ async def async_request_deepspeed_mii(
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
"""Request an LLM using Deepspeed MII"""
|
||||
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
|
||||
payload = {
|
||||
"prompt": request_func_input.prompt,
|
||||
@@ -464,16 +447,19 @@ async def async_request_deepspeed_mii(
|
||||
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=request_func_input.api_url, json=payload) as response:
|
||||
async with session.post(url=request_func_input.api_url,
|
||||
json=payload) as response:
|
||||
if response.status == 200:
|
||||
parsed_resp = await response.json()
|
||||
output.latency = time.perf_counter() - st
|
||||
if "choices" in parsed_resp:
|
||||
output.generated_text = parsed_resp["choices"][0]["text"]
|
||||
output.generated_text = parsed_resp["choices"][0][
|
||||
"text"]
|
||||
elif "text" in parsed_resp:
|
||||
output.generated_text = parsed_resp["text"][0]
|
||||
else:
|
||||
output.error = "Unexpected response format: " "neither 'choices' nor 'text' found"
|
||||
output.error = ("Unexpected response format: "
|
||||
"neither 'choices' nor 'text' found")
|
||||
output.success = False
|
||||
output.success = True
|
||||
else:
|
||||
@@ -499,22 +485,26 @@ async def async_request_openai_completions(
|
||||
("completions", "profile")
|
||||
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
# "temperature": 0.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
# "stream_options": {
|
||||
#"stream_options": {
|
||||
# "include_usage": True,
|
||||
# },
|
||||
#},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
@@ -523,7 +513,8 @@ async def async_request_openai_completions(
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
first_chunk_received = False
|
||||
async for chunk_bytes in response.content:
|
||||
@@ -531,7 +522,8 @@ async def async_request_openai_completions(
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
# print("####chunk:", chunk, type(chunk))
|
||||
data = json.loads(chunk)
|
||||
@@ -552,19 +544,21 @@ async def async_request_openai_completions(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
output.error = (
|
||||
"Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
|
||||
)
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!")
|
||||
output.generated_text = generated_text
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
@@ -587,24 +581,25 @@ async def async_request_openai_audio(
|
||||
"""Request an LLM using OpenAI"""
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("transcriptions", "translations")
|
||||
), "OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||
("transcriptions", "translations"
|
||||
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||
"or `translations`."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"language": "en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True,
|
||||
"stream_continuous_usage_stats": True
|
||||
}
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
@@ -619,9 +614,9 @@ async def async_request_openai_audio(
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
|
||||
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", f, content_type="audio/wav")
|
||||
form.add_field('file', f, content_type='audio/wav')
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
@@ -633,20 +628,24 @@ async def async_request_openai_audio(
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, data=form, headers=headers) as response:
|
||||
async with session.post(url=api_url,
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
@@ -654,11 +653,13 @@ async def async_request_openai_audio(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
@@ -692,11 +693,8 @@ ASYNC_REQUEST_FUNCS = {
|
||||
}
|
||||
|
||||
OPENAI_COMPATIBLE_BACKENDS = [
|
||||
k
|
||||
for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v
|
||||
in (
|
||||
async_request_openai_completions,
|
||||
async_request_eb_openai_chat_completions,
|
||||
)
|
||||
k for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions,
|
||||
async_request_eb_openai_chat_completions)
|
||||
]
|
||||
|
||||
|
@@ -26,10 +26,10 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from PIL import Image
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -39,7 +39,6 @@ class SampleRequest:
|
||||
Represents a single inference request for benchmarking.
|
||||
"""
|
||||
|
||||
no: int
|
||||
prompt: Union[str, Any]
|
||||
history_QA: Union[str, Any]
|
||||
json_data: Optional[dict]
|
||||
@@ -49,7 +48,6 @@ class SampleRequest:
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
"""BenchmarkDataset"""
|
||||
|
||||
DEFAULT_SEED = 0
|
||||
IS_MULTIMODAL = False
|
||||
|
||||
@@ -57,7 +55,6 @@ class BenchmarkDataset(ABC):
|
||||
self,
|
||||
dataset_path: Optional[str] = None,
|
||||
random_seed: int = DEFAULT_SEED,
|
||||
shuffle: bool = False,
|
||||
hyperparameter_path: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -71,9 +68,9 @@ class BenchmarkDataset(ABC):
|
||||
self.dataset_path = dataset_path
|
||||
# Set the random seed, ensuring that a None value is replaced with the
|
||||
# default seed.
|
||||
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
|
||||
self.random_seed = (random_seed
|
||||
if random_seed is not None else self.DEFAULT_SEED)
|
||||
self.data = None
|
||||
self.shuffle = shuffle
|
||||
self.hyperparameter_path = hyperparameter_path
|
||||
self.hyperparameters = {}
|
||||
|
||||
@@ -88,7 +85,8 @@ class BenchmarkDataset(ABC):
|
||||
NotImplementedError: If a subclass does not implement this method.
|
||||
"""
|
||||
# TODO (jenniferzhao): add support for downloading data
|
||||
raise NotImplementedError("load_data must be implemented in subclasses.")
|
||||
raise NotImplementedError(
|
||||
"load_data must be implemented in subclasses.")
|
||||
|
||||
@abstractmethod
|
||||
def sample(self, num_requests: int) -> list[SampleRequest]:
|
||||
@@ -107,7 +105,8 @@ class BenchmarkDataset(ABC):
|
||||
"""
|
||||
raise NotImplementedError("sample must be implemented in subclasses.")
|
||||
|
||||
def maybe_oversample_requests(self, requests: list[SampleRequest], num_requests: int) -> None:
|
||||
def maybe_oversample_requests(self, requests: list[SampleRequest],
|
||||
num_requests: int) -> None:
|
||||
"""
|
||||
Oversamples the list of requests if its size is less than the desired
|
||||
number.
|
||||
@@ -118,9 +117,11 @@ class BenchmarkDataset(ABC):
|
||||
"""
|
||||
if len(requests) < num_requests:
|
||||
random.seed(self.random_seed)
|
||||
additional = random.choices(requests, k=num_requests - len(requests))
|
||||
additional = random.choices(requests,
|
||||
k=num_requests - len(requests))
|
||||
requests.extend(additional)
|
||||
logger.info("Oversampled requests to reach %d total samples.", num_requests)
|
||||
logger.info("Oversampled requests to reach %d total samples.",
|
||||
num_requests)
|
||||
|
||||
|
||||
def is_valid_sequence(
|
||||
@@ -140,12 +141,14 @@ def is_valid_sequence(
|
||||
"""
|
||||
# Check for invalid conditions
|
||||
prompt_too_short = prompt_len < min_len
|
||||
output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
|
||||
output_too_short = (not skip_min_output_len_check) and (output_len
|
||||
< min_len)
|
||||
prompt_too_long = prompt_len > max_prompt_len
|
||||
combined_too_long = (prompt_len + output_len) > max_total_len
|
||||
|
||||
# Return True if none of the invalid conditions are met
|
||||
return not (prompt_too_short or output_too_short or prompt_too_long or combined_too_long)
|
||||
return not (prompt_too_short or output_too_short or prompt_too_long
|
||||
or combined_too_long)
|
||||
|
||||
|
||||
def process_image(image: Any) -> Mapping[str, Any]:
|
||||
@@ -168,25 +171,28 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
Raises:
|
||||
ValueError: If the input is not a supported type.
|
||||
"""
|
||||
if isinstance(image, dict) and "bytes" in image:
|
||||
image = Image.open(BytesIO(image["bytes"]))
|
||||
if isinstance(image, dict) and 'bytes' in image:
|
||||
image = Image.open(BytesIO(image['bytes']))
|
||||
if isinstance(image, Image.Image):
|
||||
image = image.convert("RGB")
|
||||
with io.BytesIO() as image_data:
|
||||
image.save(image_data, format="JPEG")
|
||||
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
|
||||
image_base64 = base64.b64encode(
|
||||
image_data.getvalue()).decode("utf-8")
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(image, str):
|
||||
image_url = image if image.startswith(("http://", "file://")) else f"file://{image}"
|
||||
image_url = (image if image.startswith(
|
||||
("http://", "file://")) else f"file://{image}")
|
||||
return {"type": "image_url", "image_url": {"url": image_url}}
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid image input {image}. Must be a PIL.Image.Image" " or str or dictionary with raw image bytes."
|
||||
)
|
||||
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
|
||||
" or str or dictionary with raw image bytes.")
|
||||
|
||||
|
||||
class EBDataset(BenchmarkDataset):
@@ -213,10 +219,6 @@ class EBDataset(BenchmarkDataset):
with open(self.dataset_path, encoding="utf-8") as f:
self.data = [json.loads(i.strip()) for i in f.readlines()]

if self.shuffle:
random.seed(self.random_seed)
random.shuffle(self.data)

def sample(
self,
num_requests: int,
@@ -227,7 +229,6 @@ class EBDataset(BenchmarkDataset):
**kwargs,
) -> list:
samples: list = []
cnt = 1
for entry in self.data:
if len(samples) >= num_requests:
break
@@ -241,17 +242,15 @@ class EBDataset(BenchmarkDataset):
new_output_len = int(entry["max_dec_len"])

if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(prompt, None)
prompt = self.apply_multimodal_chat_transformation(
prompt, None)
samples.append(
SampleRequest(
no=cnt,
prompt=prompt,
prompt_len=self.prompt_len,
history_QA=[],
expected_output_len=new_output_len,
)
)
cnt += 1
))

self.maybe_oversample_requests(samples, num_requests)
return samples
@@ -262,7 +261,6 @@ class EBChatDataset(BenchmarkDataset):
Implements the ShareGPT dataset. Loads data from a JSON file and generates
sample requests based on conversation turns.
"""

prompt_len: int

def __init__(self, **kwargs) -> None:
@@ -276,10 +274,6 @@ class EBChatDataset(BenchmarkDataset):
with open(self.dataset_path, encoding="utf-8") as f:
self.data = [json.loads(i.strip()) for i in f.readlines()]

if self.shuffle:
random.seed(self.random_seed)
random.shuffle(self.data)

def sample(
self,
num_requests: int,
@@ -290,7 +284,6 @@ class EBChatDataset(BenchmarkDataset):
**kwargs,
) -> list:
samples: list = []
cnt = 1
for entry in self.data:
if len(samples) >= num_requests:
break
@@ -300,18 +293,17 @@ class EBChatDataset(BenchmarkDataset):
new_output_len = int(entry.get("max_tokens", 12288))

if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(prompt, None)
prompt = self.apply_multimodal_chat_transformation(
prompt, None)
samples.append(
SampleRequest(
no=cnt,
json_data=json_data,
prompt=prompt,
prompt_len=0,
history_QA=history_QA,
expected_output_len=new_output_len,
)
)
cnt += 1
))

self.maybe_oversample_requests(samples, num_requests)
return samples
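Both dataset classes read a JSON-lines file, optionally shuffle it, and draw up to num_requests samples; the speculative-decoding benchmark script listed next (dropped in this compare) uses exactly that pattern. A minimal sketch with a hypothetical file path:

    dataset = EBChatDataset(dataset_path="data/requests.jsonl")  # hypothetical path
    requests = dataset.sample(num_requests=16)
    # each element is a SampleRequest(no=..., prompt=..., history_QA=..., expected_output_len=...)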
@@ -1,178 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import asyncio
import contextlib
import os
from typing import Union

from benchmark_dataset import EBChatDataset, EBDataset
from benchmark_serving import benchmark


def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
    dataset_mapping = {
        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
    }

    try:
        input_requests = dataset_mapping[dataset_name]()
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    return input_requests


class FakeTokenizer:
    def encode(self, text: str, add_special_tokens: bool = False):
        return []


def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    selected_percentile_metrics = ["s_itl"]
    selected_percentiles = []
    # Run benchmark
    results = asyncio.run(
        benchmark(
            backend="openai-chat",
            api_url=f"{base_url}/v1/chat/completions",
            base_url=base_url,
            model_id="default",
            model_name="default",
            input_requests=input_requests,
            hyper_parameters={},
            logprobs=None,
            request_rate=float("inf"),
            burstiness=1.0,
            disable_tqdm=disable_tqdm,
            profile=False,
            selected_percentile_metrics=selected_percentile_metrics,
            selected_percentiles=selected_percentiles,
            ignore_eos=False,
            goodput_config_dict=None,
            max_concurrency=max_concurrency,
            lora_modules=None,
            extra_body=None,
        )
    )

    record = {
        "mean_s_itl_ms": results["mean_s_itl_ms"],
    }

    return record


def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):

    tmp = 0.0
    for i in range(draft_token_step):
        tmp += pow(acceptance_rate, i + 1)

    r_ac = tmp / (1 + tmp)

    return t_ori / ((1 - r_ac) * t_mtp)


def main(args):
    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)

    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError("--max_concurrency should be same length as --s_itl_base_model")

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Wramup
        print("Starting warmup...")
        with open(os.devnull, "w") as f:
            with contextlib.redirect_stdout(f):
                send_one_batch(
                    base_url,
                    max_concurrency,
                    input_requests[0:max_concurrency],
                    True,
                )

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        metric_header = "Speed up"
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))
        print("=" * 50)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
    )
    parser.add_argument(
        "--s_itl-base-model",
        type=float,
        nargs="+",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
    )
    args = parser.parse_args()

    main(args)
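Spelled out, calculate_speedup above estimates the accepted-token ratio from the acceptance rate $a$ and the number of draft token steps $s$:

    r_{ac} = \frac{\sum_{i=1}^{s} a^{i}}{1 + \sum_{i=1}^{s} a^{i}}, \qquad
    \text{speedup} = \frac{t_{ori}}{(1 - r_{ac})\, t_{mtp}}
                   = \Bigl(1 + \sum_{i=1}^{s} a^{i}\Bigr) \frac{t_{ori}}{t_{mtp}},

where t_ori is the baseline inter-token latency passed via --s_itl-base-model and t_mtp is the measured mean_s_itl_ms of the MTP run.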
File diff suppressed because it is too large
@@ -24,11 +24,9 @@ import os
from typing import Any


def convert_to_pytorch_benchmark_format(
args: argparse.Namespace,
metrics: dict[str, list],
extra_info: dict[str, Any],
) -> list:
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
metrics: dict[str, list],
extra_info: dict[str, Any]) -> list:
"""
Save the benchmark results in the format used by PyTorch OSS benchmark with
on metric per record
@@ -56,10 +54,12 @@ def convert_to_pytorch_benchmark_format(
},
}

tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
tp = record["benchmark"]["extra_info"]["args"].get(
"tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = extra_info["tensor_parallel_size"]
record["benchmark"]["extra_info"]["args"][
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]

records.append(record)

@@ -68,7 +68,6 @@ def convert_to_pytorch_benchmark_format(

class InfEncoder(json.JSONEncoder):
"""InfEncoder"""

def clear_inf(self, o: Any):
"""clear_inf"""
if isinstance(o, dict):
@@ -88,3 +87,4 @@ def write_to_json(filename: str, records: list) -> None:
"""write_to_json"""
with open(filename, "w") as f:
json.dump(records, f, cls=InfEncoder)
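A minimal sketch of how the helpers above fit together; the metric name and values here are illustrative only, and the full record schema is the PyTorch OSS benchmark format referenced in the docstring:

    import argparse

    args = argparse.Namespace(model="default")          # stand-in for the benchmark CLI namespace
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"mean_s_itl_ms": [12.3, 11.8]},         # hypothetical metric -> list of sampled values
        extra_info={"tensor_parallel_size": 1},
    )
    write_to_json("benchmark_results.json", records)     # serialized with InfEncoder (inf handled via clear_inf)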
File diff suppressed because it is too large
@@ -3,4 +3,3 @@ tqdm
|
||||
numpy
|
||||
Pillow
|
||||
pyyaml
|
||||
requests
|
||||
|
@@ -7,4 +7,4 @@ tensor_parallel_size: 1
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
quantization: wint4
|
||||
reasoning_parser: ernie-45-vl
|
||||
reasoning_parser: ernie-45-vl
|
@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
|
||||
pd_comm_port: "2334"
|
||||
max_num_batched_tokens: 384
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
@@ -9,4 +9,4 @@ cache_queue_port: 55664
|
||||
engine_worker_queue_port: 6677
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7675,7676,7677,7678"
|
||||
pd_comm_port: "2333"
|
||||
pd_comm_port: "2333"
|
@@ -3,4 +3,3 @@ max_num_seqs: 96
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 4
|
||||
quantization: wint4
|
||||
|
@@ -10,4 +10,4 @@ engine_worker_queue_port: 6677
|
||||
num_gpu_blocks_override: 1024
|
||||
cache_transfer_protocol: "rdma"
|
||||
rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678"
|
||||
pd_comm_port: "2334"
|
||||
pd_comm_port: "2334"
|
@@ -10,4 +10,4 @@ splitwise_role: decode
|
||||
engine_worker_queue_port: 6678
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7671,7672,7673,7674"
|
||||
pd_comm_port: "2334"
|
||||
pd_comm_port: "2334"
|
@@ -9,4 +9,4 @@ cache_queue_port: 55664
|
||||
engine_worker_queue_port: 6677
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7675,7676,7677,7678"
|
||||
pd_comm_port: "2333"
|
||||
pd_comm_port: "2333"
|
@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
|
||||
pd_comm_port: "2334"
|
||||
max_num_batched_tokens: 384
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
@@ -9,4 +9,4 @@ cache_queue_port: 55664
|
||||
engine_worker_queue_port: 6677
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7675,7676,7677,7678"
|
||||
pd_comm_port: "2333"
|
||||
pd_comm_port: "2333"
|
@@ -3,4 +3,3 @@ max_num_seqs: 96
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint8
|
||||
|
@@ -2,5 +2,4 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -2,5 +2,4 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,5 +3,4 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint8
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,5 +3,4 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint8
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -2,5 +2,4 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,5 +3,4 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint4
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,5 +3,4 @@ max_num_seqs: 96
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 4
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -2,5 +2,4 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -2,5 +2,4 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,5 +3,4 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wfp8afp8
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -2,5 +2,4 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -2,5 +2,4 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,5 +3,4 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint8
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,5 +3,4 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint8
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -2,5 +2,4 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,5 +3,4 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint4
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
enable_static_graph_inference: True
|
||||
|
@@ -3,4 +3,4 @@ max_num_seqs: 75
|
||||
gpu_memory_utilization: 0.85
|
||||
kv_cache_ratio: 0.75
|
||||
quantization: wint4
|
||||
tensor_parallel_size: 4
|
||||
tensor_parallel_size: 4
|
@@ -3,4 +3,4 @@ max_num_seqs: 25
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.75
|
||||
quantization: wint8
|
||||
tensor_parallel_size: 4
|
||||
tensor_parallel_size: 4
|
@@ -1,3 +0,0 @@
|
||||
metadata:
|
||||
min_tokens: 32
|
||||
max_tokens: 33
|
@@ -5,4 +5,4 @@ metadata:
|
||||
max_tokens: 12288
|
||||
repetition_penalty: 1.05
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
||||
presence_penalty: 0
|
@@ -5,4 +5,4 @@ metadata:
|
||||
max_tokens: 12288
|
||||
repetition_penalty: 1.0
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 1.5
|
||||
presence_penalty: 1.5
|
@@ -1,11 +0,0 @@
|
||||
top_p: 1.0
|
||||
temperature: 1.0
|
||||
metadata:
|
||||
min_tokens: 1
|
||||
max_tokens: 30721
|
||||
repetition_penalty: 1.0
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
||||
skip_special_tokens: false
|
||||
chat_template_kwargs:
|
||||
enable_thinking: true
|
@@ -3,4 +3,4 @@ max_num_seqs: 64
|
||||
gpu_memory_utilization: 0.9
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint8
|
||||
reasoning_parser: ernie-x1
|
||||
reasoning_parser: ernie-x1
|
51
build.sh
@@ -18,9 +18,6 @@ BUILD_WHEEL=${1:-1}
|
||||
PYTHON_VERSION=${2:-"python"}
|
||||
export python=$PYTHON_VERSION
|
||||
FD_CPU_USE_BF16=${3:-"false"}
|
||||
# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
|
||||
# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
|
||||
# These will be translated to 90a / 100a in setup_ops.py for specific features.
|
||||
FD_BUILDING_ARCS=${4:-""}
|
||||
|
||||
|
||||
@@ -77,10 +74,8 @@ function copy_ops(){
|
||||
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
|
||||
if [ "$is_rocm" = "True" ]; then
|
||||
DEVICE_TYPE="rocm"
|
||||
mkdir -p ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
|
||||
echo -e "BASE and ROCM ops have been copy to fastdeploy"
|
||||
echo -e "ROCM ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
mkdir -p ../fastdeploy/model_executor/ops/base
|
||||
@@ -109,23 +104,6 @@ function copy_ops(){
|
||||
return
|
||||
fi
|
||||
|
||||
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
|
||||
if [ "$if_corex" = "True" ]; then
|
||||
DEVICE_TYPE="iluvatar-gpu"
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
|
||||
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
|
||||
if [ "$is_gcu" = "True" ]; then
|
||||
DEVICE_TYPE="gcu"
|
||||
cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
|
||||
echo -e "gcu ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
DEVICE_TYPE="cpu"
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cd ../../../../
|
||||
@@ -185,24 +163,17 @@ function build_and_install() {
|
||||
exit 1
|
||||
fi
|
||||
echo -e "${BLUE}[build]${NONE} ${GREEN}build fastdeploy wheel success${NONE}\n"
|
||||
}
|
||||
|
||||
function version_info() {
|
||||
output_file="fastdeploy/version.txt"
|
||||
fastdeploy_git_commit_id=$(git rev-parse HEAD)
|
||||
paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
|
||||
paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
|
||||
cuda_version="nvcc-not-installed"
|
||||
if command -v nvcc &> /dev/null; then
|
||||
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
|
||||
echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
|
||||
cd $DIST_DIR
|
||||
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install
|
||||
if [ $? -ne 0 ]; then
|
||||
cd ..
|
||||
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
|
||||
exit 1
|
||||
fi
|
||||
cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
|
||||
|
||||
echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
|
||||
echo "Paddle version: $paddle_version" >> $output_file
|
||||
echo "Paddle GIT COMMIT ID: $paddle_git_commit_id" >> $output_file
|
||||
echo "CUDA version: $cuda_version" >> $output_file
|
||||
echo "CXX compiler version: $cxx_version" >> $output_file
|
||||
echo -e "${BLUE}[install]${NONE} ${GREEN}fastdeploy install success${NONE}\n"
|
||||
cd ..
|
||||
}
|
||||
|
||||
function cleanup() {
|
||||
@@ -236,7 +207,6 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
|
||||
set -e
|
||||
|
||||
init
|
||||
version_info
|
||||
build_and_install_ops
|
||||
build_and_install
|
||||
cleanup
|
||||
@@ -267,7 +237,6 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
|
||||
else
|
||||
init
|
||||
build_and_install_ops
|
||||
version_info
|
||||
rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR
|
||||
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
|
||||
fi
|
||||
|
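The version_info function in the build.sh hunk above writes a small "key: value" manifest to fastdeploy/version.txt (commit ids, Paddle, CUDA and compiler versions). A sketch of reading it back, assuming only the line format shown in the echo statements:

    def read_version_info(path="fastdeploy/version.txt"):
        info = {}
        with open(path, encoding="utf-8") as f:
            for line in f:
                key, _, value = line.partition(":")
                info[key.strip()] = value.strip()
        return info

    # e.g. info["Paddle version"], info["CUDA version"], info["CXX compiler version"]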
@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
|
||||
from . import jit
|
||||
from .jit_kernels import (
|
||||
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
|
||||
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
from . import interleave_ffma
|
||||
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
|
||||
index fcb377e..db9d6f3 100644
|
||||
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
|
||||
import subprocess
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
|
||||
|
||||
|
||||
|
||||
def run_cuobjdump(file_path):
|
||||
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
|
||||
index 66c370a..4761426 100644
|
||||
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from .template import map_ctype
|
||||
@@ -35,7 +35,7 @@ class Runtime:
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict, Iterable, Tuple
|
||||
|
||||
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
|
||||
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
|
||||
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
|
||||
}
|
||||
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -27,25 +27,25 @@ genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
|
||||
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
if hasattr(value, 'data_ptr'):
|
||||
- if value.dtype == torch.int:
|
||||
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
|
||||
+import paddle
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||
|
||||
|
||||
|
||||
|
||||
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor) -> None:
|
||||
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
- this function will do a transposing with a set of slow PyTorch operations.
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
|
||||
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
|
||||
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
|
||||
- assert n % 64 == 0 and k % 128 == 0
|
||||
+ # assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
- assert n > 0 and k > 0
|
||||
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
|
||||
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
+ # assert out.dtype == paddle.bfloat16
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
|
||||
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
|
||||
+ this function will do a transposing with a set of slow Pypaddle operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
m__ = m_indices.numel()
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ == m__ and k == k_ and n == n_
|
||||
- assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
|
||||
+ # assert m_indices.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
|
||||
)
|
||||
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
runtime(*args)
|
||||
|
||||
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
|
||||
+ # assert masked_m.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
- torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
@@ -51,10 +51,10 @@ class JITTuner:
|
||||
continue
|
||||
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
|
||||
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
|
||||
"""
|
||||
global _num_sms
|
||||
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
|
||||
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
return ceil_div(x, alignment) * alignment
|
||||
|
||||
|
||||
|
||||
|
||||
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
|
||||
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
||||
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
||||
|
||||
|
||||
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
|
||||
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
|
||||
# Normal layout requires transposing
|
||||
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
+ aligned_x = paddle.transpose(
|
||||
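The TMA-alignment helpers being patched in this hunk are compact enough to restate. A Python sketch of the padding rule and the column-major fast-path check (the real get_tma_aligned_size derives the alignment from the element size; here it is taken as a parameter):

    def ceil_div(x, y):
        return (x + y - 1) // y

    def get_tma_aligned_size(m, alignment):
        # pad m up to the next multiple of the TMA alignment
        return ceil_div(m, alignment) * alignment

    def is_col_major_tma_aligned_2d(strides, aligned_m):
        # fast path of get_col_major_tma_aligned_tensor: the tensor is already
        # column-major and padded along M, so no transpose/copy is needed
        return strides[0] == 1 and strides[1] == aligned_m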
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
|
||||
-import torch.distributed as dist
|
||||
+import paddle
|
||||
+import paddle.distributed as dist
|
||||
|
||||
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
high_precision: bool = False):
|
||||
# Flush L2 cache with 256 MB data
|
||||
- torch.cuda.synchronize()
|
||||
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
+ paddle.device.synchronize()
|
||||
+ paddle.device.cuda.synchronize()
|
||||
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
|
||||
cache.zero_()
|
||||
|
||||
|
||||
# Warmup
|
||||
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
|
||||
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
x @ y
|
||||
|
||||
|
||||
# Testing
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
|
||||
end_event.record()
|
||||
- torch.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
|
||||
|
||||
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
@@ -636,7 +636,8 @@ index d5cdd01..5237f09 100644
|
||||
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
|
||||
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
|
||||
fn()
|
||||
|
||||
|
||||
if not using_nsys:
|
||||
--
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
@@ -46,8 +46,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
@@ -165,8 +165,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
lambda_batch_ids,
|
||||
lambda_tile_ids_per_batch,
|
||||
@@ -202,8 +202,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
@@ -274,8 +274,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
@@ -297,8 +297,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
@@ -322,8 +322,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
@@ -346,8 +346,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
@@ -403,8 +403,8 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
@@ -462,7 +462,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
|
||||
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
@@ -473,8 +473,8 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
@@ -550,8 +550,8 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& batch_id_per_token_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& padding_offsets_shape,
|
||||
const std::vector<int64_t>& cum_offsets_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
@@ -610,8 +610,8 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& batch_id_per_token_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& padding_offsets_dtype,
|
||||
const paddle::DataType& cum_offsets_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
@@ -688,8 +688,8 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
|
@@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -114,7 +114,8 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_b_stride = HEAD_DIM;
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_base_seq_id_this_block =
|
||||
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
@@ -404,7 +405,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -476,7 +477,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_b_stride = HEAD_DIM;
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
q_head_idx * HEAD_DIM +
|
||||
@@ -773,8 +775,8 @@ void MultiQueryAppendAttention(
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -880,7 +882,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -937,7 +939,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -972,7 +974,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1007,8 +1009,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1102,7 +1103,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1170,7 +1171,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1206,7 +1207,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1241,8 +1242,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1289,8 +1289,8 @@ void CascadeAppendAttentionC16Kernel(
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -1352,8 +1352,8 @@ void CascadeAppendAttentionC16Kernel(
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
|
@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -144,7 +144,8 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
|
||||
const uint32_t kv_b_stride = HEAD_DIM / 2;
|
||||
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_base_seq_id_this_block =
|
||||
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
@@ -503,7 +504,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -600,7 +601,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
|
||||
const uint32_t kv_b_stride = HEAD_DIM / 2;
|
||||
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
q_head_idx * HEAD_DIM +
|
||||
@@ -960,8 +962,8 @@ void MultiQueryAppendC4Attention(
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -1086,7 +1088,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1149,7 +1151,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1184,7 +1186,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1219,8 +1221,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1332,7 +1333,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1408,7 +1409,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1443,7 +1444,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1478,8 +1479,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1526,8 +1526,8 @@ void CascadeAppendAttentionC4Kernel(
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -1593,8 +1593,8 @@ void CascadeAppendAttentionC4Kernel(
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
|
@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -151,7 +151,8 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_b_stride = HEAD_DIM;
|
||||
const uint32_t kv_d_stride = BLOCK_SIZE;
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_base_seq_id_this_block =
|
||||
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
@@ -472,7 +473,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -574,7 +575,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_b_stride = HEAD_DIM;
|
||||
const uint32_t kv_d_stride = BLOCK_SIZE;
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
q_head_idx * HEAD_DIM +
|
||||
@@ -897,8 +899,8 @@ void MultiQueryAppendC8Attention(
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -1052,7 +1054,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1109,7 +1111,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1144,7 +1146,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1179,8 +1181,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1316,7 +1317,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1386,7 +1387,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1416,7 +1417,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1451,8 +1452,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1499,8 +1499,8 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -1564,8 +1564,8 @@ void CascadeAppendAttentionC8Kernel(
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
|
@@ -1852,7 +1852,7 @@ __global__ void merge_multi_chunks_kernel(
|
||||
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
|
||||
const int* __restrict__ seq_lens_q,
|
||||
const int* __restrict__ seq_lens_kv,
|
||||
const int* __restrict__ batch_id_per_token,
|
||||
const int* __restrict__ padding_offsets,
|
||||
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
T* __restrict__ out,
|
||||
@@ -1866,7 +1866,8 @@ __global__ void merge_multi_chunks_kernel(
|
||||
const int head_dim) {
|
||||
const int vid = threadIdx.x, hid = threadIdx.y;
|
||||
const int qid = blockIdx.x;
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -2110,7 +2111,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT *__restrict__ out,
|
||||
@@ -2126,7 +2127,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int bid = blockIdx.x, hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int seq_len_q = seq_lens_q[bid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[bid];
|
||||
@@ -2239,8 +2240,7 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ batch_id_per_token,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ padding_offsets,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT *__restrict__ out,
|
||||
@@ -2259,8 +2259,9 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
const uint32_t local_seq_id = ori_token_id % max_seq_len;
|
||||
const int seq_len_q = seq_lens_q[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
int seq_len_kv = seq_lens_kv[bid];
|
||||
|
@@ -40,8 +40,8 @@ void CascadeAppendAttentionC16Kernel(
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -85,8 +85,8 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -130,8 +130,8 @@ void CascadeAppendAttentionC4Kernel(
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -175,8 +175,8 @@ void CascadeAppendAttentionKernel(
|
||||
const paddle::Tensor& seq_lens_q,
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -211,8 +211,8 @@ void CascadeAppendAttentionKernel(
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
@@ -246,8 +246,8 @@ void CascadeAppendAttentionKernel(
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
@@ -281,8 +281,8 @@ void CascadeAppendAttentionKernel(
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
@@ -316,8 +316,8 @@ void CascadeAppendAttentionKernel(
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
|
@@ -1,236 +0,0 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "multi_head_latent_attention_kernel.h"
template <size_t vec_size, typename T>
struct softmax_state_t {
  AlignedVector<T, vec_size> o;
  T m;
  T d;

  __device__ __forceinline__ void init() {
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((half2*)(&o) + i) = make_half2(0, 0);
      }
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
      }
    }
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = __float2half(-5e4f);
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
      m = __float2bfloat16(-3.38953e38f);
    }
  }

  __device__ __forceinline__ softmax_state_t() { init(); }

  __device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
                                        T other_m,
                                        T other_d) {
    // using kType = typename cascade_attn_nv_type2_traits<T>::type;
    T m_prev = m, d_prev = d;
    m = m_prev > other_m ? m_prev : other_m;
    T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);
    d = d_prev * scale1 + other_d * scale2;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] = o[i] * scale1 + other_o[i] * scale2;
    }
  }

  __device__ __forceinline__ void normalize() {
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] /= d;
    }
  }
};
|
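softmax_state_t is the usual online-softmax accumulator: merge() folds another partial result (other_o, other_m, other_d) into the running (o, m, d) by rescaling both sides to the shared maximum, and normalize() divides by the accumulated denominator at the end. A minimal host-side sketch of the same update rule (illustration only, not part of FastDeploy; plain float and std::vector stand in for the device vector types):

#include <algorithm>
#include <cmath>
#include <vector>

// Two partial softmax results over disjoint key chunks can be combined
// exactly by rescaling each side to the common running maximum.
struct PartialSoftmax {
  float m;                 // max logit seen so far
  float d;                 // sum of exp(logit - m)
  std::vector<float> o;    // value accumulator, implicitly scaled by d

  void merge(const PartialSoftmax& other) {
    const float m_new = std::max(m, other.m);
    const float s1 = std::exp(m - m_new);
    const float s2 = std::exp(other.m - m_new);
    for (size_t i = 0; i < o.size(); ++i) {
      o[i] = o[i] * s1 + other.o[i] * s2;
    }
    d = d * s1 + other.d * s2;
    m = m_new;
  }

  void normalize() {
    for (float& v : o) v /= d;
  }
};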
||||
|
||||
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
|
||||
struct softmax_state_ts {
|
||||
uint32_t num_tiles_ = num_tiles;
|
||||
AlignedVector<T, vec_size> o[num_tiles];
|
||||
float m;
|
||||
float d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
#pragma unroll
|
||||
for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ softmax_state_ts() {
|
||||
init();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize(const uint32_t tile_id) {
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; i++) {
|
||||
o[tile_id][i] /= d;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
|
||||
__device__ __forceinline__ void produce_kv(CacheT *smem,
|
||||
CacheT *kv_base_gptr,
|
||||
const int * block_table_smem,
|
||||
const uint32_t seq_offset_gmem,
|
||||
const uint32_t seq_offset_smem,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t tidx,
|
||||
const uint32_t chunk_start,
|
||||
const uint32_t chunk_end) {
|
||||
int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
|
||||
if (block_id < 0) {
|
||||
block_id = 0;
|
||||
}
|
||||
const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
|
||||
// 8/16 T/int8 each time
|
||||
const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
|
||||
const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
|
||||
smem + smem_offset_base + vid * CACHE_VEC_SIZE,
|
||||
kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
|
||||
seq_offset_gmem < chunk_end
|
||||
);
|
||||
}
|
||||
}
|
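produce_kv resolves a logical kv position through the per-sequence block table before issuing the predicated shared-memory load. A rough sketch of that address computation in isolation (hypothetical helper, mirroring k_offset_base above; the clamp of negative block ids matches the guard in produce_kv, whose corresponding loads are predicated off anyway):

#include <cstdint>

// Map a logical kv position to a flat element offset inside the paged cache
// laid out as [max_block_num, kv_num_heads, block_size, head_dim].
inline uint64_t paged_kv_offset(const int* block_table,  // [max_blocks_per_seq]
                                uint32_t seq_offset,
                                uint32_t kv_head_idx,
                                uint32_t kv_num_heads,
                                uint32_t block_size,
                                uint32_t head_dim) {
  int block_id = block_table[seq_offset / block_size];
  if (block_id < 0) block_id = 0;  // unallocated block: keep the address in range
  const uint32_t block_offset = seq_offset % block_size;
  return ((static_cast<uint64_t>(block_id) * kv_num_heads + kv_head_idx) * block_size +
          block_offset) * head_dim;
}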
||||
|
||||
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
|
||||
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
|
||||
const CacheT* k_smem,
|
||||
const uint32_t kv_idx_base,
|
||||
const uint32_t stage_idx,
|
||||
const uint32_t iter_base,
|
||||
const uint32_t iter_bound,
|
||||
const uint32_t tidx,
|
||||
const uint32_t gid,
|
||||
const float scale,
|
||||
float *s,
|
||||
softmax_state_ts<vec_size, T, num_tile_v>& st) {
|
||||
const CacheT* smem;
|
||||
AlignedVector<T, vec_size> q_vec;
|
||||
AlignedVector<T, vec_size> k_vec;
|
||||
float m_prev = st.m;
|
||||
// smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
|
||||
smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
|
||||
if (iter_base + j < iter_bound) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s[j] = 0.f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
s[j] = 0.f;
|
||||
}
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
|
||||
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
|
||||
s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s[j] = -5e4f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
s[j] = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
st.m = st.m > s[j] ? st.m : s[j];
|
||||
}
|
||||
|
||||
// T o_scale = hexp(m_prev - st.m);
|
||||
float o_scale = __expf(m_prev - st.m);
|
||||
st.d *= o_scale;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
|
||||
// s[j] = hexp(s[j] - st.m);
|
||||
s[j] = __expf(s[j] - st.m);
|
||||
st.d += s[j];
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
st.o[tile_id][i] *= o_scale;
|
||||
}
|
||||
}
|
||||
}
|
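compute_qk accumulates a partial q.k dot product per thread and then folds the bdx lanes together with an XOR butterfly shuffle before updating the running max. The reduction pattern on its own, as a sketch that assumes a full 32-lane warp (compute_qk starts the ladder at bdx / 2 instead):

__device__ __forceinline__ float warp_allreduce_sum(float v) {
  // XOR butterfly: after log2(32) steps every lane holds the full sum.
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_xor_sync(0xffffffff, v, offset, 32);
  }
  return v;
}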
||||
|
||||
template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
|
||||
__device__ __forceinline__ void compute_sv(const float *s,
|
||||
const CacheT *base_v_smem,
|
||||
const uint32_t stage_idx,
|
||||
const uint32_t iter_base,
|
||||
const uint32_t iter_bound,
|
||||
const uint32_t tidx,
|
||||
softmax_state_ts<vec_size, T, num_tile>& st) {
|
||||
const CacheT* v_smem;
|
||||
AlignedVector<T, vec_size> v_vec;
|
||||
#pragma unroll
|
||||
for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
|
||||
v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
|
||||
uint32_t tile_id = vid / bdx;
|
||||
#pragma unroll
|
||||
for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
|
||||
st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,560 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decode_attention_func.cuh"
#define CHECK(call)                                                   \
  do {                                                                \
    const cudaError_t error_code = call;                              \
    if (error_code != cudaSuccess) {                                  \
      printf("CUDA Error:\n");                                        \
      printf(" File: %s\n", __FILE__);                                \
      printf(" Line %d:\n", __LINE__);                                \
      printf(" Error code:%d\n", error_code);                         \
      printf(" Error text:%s\n", cudaGetErrorString(error_code));     \
      exit(1);                                                        \
    }                                                                 \
  } while (0)
|
||||
|
||||
template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim]
|
||||
const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads]
|
||||
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cu_seqlens_q,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const float in_scale,
|
||||
const int num_chunks,
|
||||
const int chunk_size,
|
||||
const int max_seq_len,
|
||||
const int num_heads,
|
||||
const int head_dim) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int qid = blockIdx.x, hid = blockIdx.y;
|
||||
const int seq_len_q = seq_lens_q[qid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[qid];
|
||||
if (seq_len_kv == 0) return;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
|
||||
return;
|
||||
}
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ T md_smem[bdy * 2];
|
||||
|
||||
const int start_token_ids = cu_seqlens_q[qid];
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
T m;
|
||||
T d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = __float2half(-5e4f);
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = __float2bfloat16(-3.38953e38f);
|
||||
}
|
||||
// merge per ty
|
||||
#pragma unroll 2
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
|
||||
T m_prev = m;
|
||||
T d_prev = d;
|
||||
const T m_now = multi_m[offset];
|
||||
const T d_now = multi_d[offset];
|
||||
m = m_prev > m_now ? m_prev : m_now;
|
||||
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
|
||||
// merge bdy
|
||||
softmax_state_t<vec_size, T> st{};
|
||||
const uint32_t iter_num = min(num_chunks_this_seq, bdy);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < iter_num; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
|
||||
AlignedVector<OutT, vec_size> out_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size; ++i) {
|
||||
out_vec[i] = static_cast<OutT>(st.o[i]);
|
||||
}
|
||||
Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
|
||||
template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
|
||||
uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
|
||||
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim]
|
||||
CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||
CacheT * __restrict__ cache_v,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cu_seqlens_q,
|
||||
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
const float scale,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim]
|
||||
T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads]
|
||||
T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads]
|
||||
OutT * __restrict__ out) {
|
||||
const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const uint32_t bid = bidx, gid = threadIdx.y;
|
||||
const uint32_t tidx = threadIdx.x;
|
||||
constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
|
||||
constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
|
||||
constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;
|
||||
|
||||
const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
|
||||
const uint32_t kv_num_heads = gridDim.z;
|
||||
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
|
||||
|
||||
const int *block_table_now = block_table + bid * max_block_num_per_seq;
|
||||
|
||||
const uint32_t num_chunks = gridDim.y;
|
||||
const uint32_t chunk_id = blockIdx.y;
|
||||
const uint32_t q_len = seq_lens_q[bid];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
uint32_t kv_len = seq_lens_kv[bid];  // cached kv length only; the q tokens of this step are added below
|
||||
if (kv_len <= 0) {
|
||||
return;
|
||||
}
|
||||
kv_len += q_len;
|
||||
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
|
||||
const uint32_t q_start_idx = cu_seqlens_q[bid];
|
||||
const uint32_t q_write_idx = cu_seqlens_q[bid];
|
||||
if (chunk_id >= num_chunk_this_seq) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
|
||||
const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
|
||||
const uint32_t chunk_len = chunk_end - chunk_start;
|
||||
|
||||
extern __shared__ uint8_t smem[];
|
||||
const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
|
||||
T *q_smem = reinterpret_cast<T*>(smem); // [HEAD_DIM_QK * sizeof(T)]
|
||||
T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
using VecT = AlignedVector<T, VEC_SIZE>;
|
||||
VecT q_vec;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
|
||||
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
|
||||
q_vec[i] *= scale;
|
||||
}
|
||||
Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
|
||||
}
|
||||
|
||||
|
||||
CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
|
||||
uint32_t stage_idx = 0;
|
||||
constexpr int loop_times = DEAL_EACH_TIME / bdy;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_STAGES; ++i) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
|
||||
const uint32_t k_seq_id = chunk_start + k_seq_offset;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
k_seq_id,
|
||||
k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
|
||||
|
||||
softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
|
||||
float s[DEAL_EACH_TIME];
|
||||
|
||||
const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
|
||||
for (int iter = 0; iter < num_iters; ++iter) {
|
||||
wait_group<NUM_STAGES - 1>();
|
||||
__syncthreads();
|
||||
// compute qk
|
||||
compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
|
||||
cu_q_smem,
|
||||
kv_smem,
|
||||
chunk_start + iter * DEAL_EACH_TIME,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
gid,
|
||||
scale,
|
||||
s,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
// compute sv
|
||||
compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
|
||||
s,
|
||||
kv_smem,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = j * bdy + gid;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
|
||||
stage_idx * DEAL_EACH_TIME + k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// normalize if not partition_kv
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
|
||||
const uint32_t tile_id = vid / bdx;
|
||||
if (!partition_kv || num_chunk_this_seq == 1) {
|
||||
st.normalize(tile_id);
|
||||
}
|
||||
if (partition_kv && num_chunk_this_seq > 1) {
|
||||
const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
|
||||
Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
tmp_m[head_idx] = st.m;
|
||||
tmp_d[head_idx] = st.d;
|
||||
} else {
|
||||
Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
|
||||
void MultiQueryDecoderAttention(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
cudaStream_t &stream,
|
||||
const paddle::Tensor &q,
|
||||
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
|
||||
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float rope_scale,
|
||||
const float rope_theta,
|
||||
const float softmax_scale,
|
||||
const float in_scale,
|
||||
paddle::Tensor *out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
constexpr int num_stages = NUM_STAGE;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T); // 8 16 32
|
||||
constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32
|
||||
constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
|
||||
constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
|
||||
constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;
|
||||
|
||||
constexpr int blocky = GROUP_SIZE;
|
||||
const int gridx = bsz;
|
||||
|
||||
constexpr int num_threads = blockx * blocky;
|
||||
|
||||
auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
|
||||
BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
|
||||
uint32_t cache_smem_bytes = 0;
|
||||
|
||||
const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
|
||||
const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
|
||||
cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
const int dev_id = 0;
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
|
||||
assert(act_blocks_per_sm > 1);
|
||||
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
const int num_blocks_need = gridx * num_chunks * kv_num_heads;
|
||||
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
|
||||
const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);
|
||||
|
||||
dim3 grids(gridx, num_chunks, kv_num_heads);
|
||||
dim3 blocks(blockx, blocky);
|
||||
if (num_chunks <= 1) {
|
||||
auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
|
||||
cache_vec_size, blockx, blocky>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
} else {
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
tmp_workspace = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
|
||||
tmp_m = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
tmp_d = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
|
||||
splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
|
||||
constexpr int mblockx = HEAD_DIM_V / vec_size;
|
||||
constexpr int bdy = 256 / mblockx;
|
||||
dim3 grids_merge(bsz, num_heads);
|
||||
dim3 blocks_merge(mblockx, bdy);
|
||||
merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
|
||||
in_scale,
|
||||
num_chunks,
|
||||
chunk_size,
|
||||
max_seq_len,
|
||||
num_heads,
|
||||
HEAD_DIM_V
|
||||
);
|
||||
}
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
}
|
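MultiQueryDecoderAttention opts its kernels into more than 48 KB of dynamic shared memory and queries occupancy before deciding between the single-chunk and split-KV launches. The launch-preparation pattern in isolation (a sketch with assumed names, not the exact FastDeploy flow):

template <typename KernelT>
int blocks_per_sm_for(KernelT kernel, int threads_per_block, size_t dyn_smem_bytes) {
  // Kernels that need >48KB of dynamic shared memory must opt in explicitly.
  if (dyn_smem_bytes >= 48 * 1024) {
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         static_cast<int>(dyn_smem_bytes));
  }
  int act_blocks_per_sm = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &act_blocks_per_sm, kernel, threads_per_block, dyn_smem_bytes);
  return act_blocks_per_sm;  // sm_count * this, versus blocks needed, gauges wave utilization
}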
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim_qk = meta_data.head_dims;
|
||||
const auto head_dim_v = meta_data.head_dims_v;
|
||||
const float rope_scale = 0.0;
|
||||
const float rope_theta = 0.0;
|
||||
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
|
||||
const uint32_t num_stage = get_cascade_attention_num_stages();
|
||||
const uint32_t num_threads = get_cascade_attention_num_threads();
|
||||
|
||||
DISPATCH_CAUSAL(causal, CAUSAL,
|
||||
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
|
||||
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
|
||||
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
|
||||
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
|
||||
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q,
|
||||
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
|
||||
}
|
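DecodeMLAAttentionKernel turns runtime shapes (causal flag, group size, head dims, block size, deal-each-time) into template parameters through the nested DISPATCH_* macros. The general shape of such a macro, shown for a boolean flag (illustrative only; the real FastDeploy macro definitions live elsewhere and may differ):

#define DISPATCH_BOOL_FLAG(flag, CONST_NAME, ...) \
  if (flag) {                                     \
    constexpr bool CONST_NAME = true;             \
    __VA_ARGS__                                   \
  } else {                                        \
    constexpr bool CONST_NAME = false;            \
    __VA_ARGS__                                   \
  }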
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
@@ -28,8 +28,8 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -65,7 +65,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
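The recurring change in these hunks swaps the decoder's per-batch start offset between two equivalent bookkeeping schemes: the padded layout derives it from cum_offsets, while the packed varlen layout reads it from cu_seqlens_q. Side by side (hypothetical helpers, assuming cum_offsets[bid] holds the padding accumulated before batch bid):

// Padded layout: every batch owns max_seq_len slots; subtract the padding
// accumulated by earlier batches to get the first real token slot.
__host__ __device__ inline int start_token_idx_padded(int bid, int max_seq_len,
                                                      const int* cum_offsets) {
  return bid * max_seq_len - cum_offsets[bid];
}

// Packed (varlen) layout: cu_seqlens_q is the prefix sum of query lengths.
__host__ __device__ inline int start_token_idx_packed(int bid, const int* cu_seqlens_q) {
  return cu_seqlens_q[bid];
}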
@@ -134,8 +134,8 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -177,7 +177,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -254,8 +254,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -293,7 +293,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -366,8 +366,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -409,7 +409,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -498,8 +498,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -523,7 +523,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -745,8 +745,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -775,7 +775,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -1047,8 +1047,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1073,7 +1073,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -1346,8 +1346,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1377,7 +1377,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -1739,8 +1739,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1766,7 +1766,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int half_block_size = block_size / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -2034,8 +2034,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -2066,7 +2066,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int half_block_size = block_size / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -2362,8 +2362,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -2389,7 +2389,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int half_block_size = block_size / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -2732,8 +2732,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -2764,7 +2764,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int half_block_size = block_size / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
|
@@ -21,8 +21,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -57,8 +57,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -79,8 +79,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -102,8 +102,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -125,8 +125,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -149,8 +149,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -182,8 +182,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -207,8 +207,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -232,8 +232,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -257,8 +257,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -282,8 +282,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -317,8 +317,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -344,8 +344,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -371,8 +371,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -398,8 +398,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -424,8 +424,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -471,8 +471,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -503,8 +503,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -536,8 +536,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -570,8 +570,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -603,8 +603,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -650,8 +650,8 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -677,8 +677,8 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -703,8 +703,8 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -729,8 +729,8 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -23,8 +23,8 @@ void DecoderWriteCacheWithRoPEKernel(
    // kv_num_heads, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -40,4 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);
    paddle::Tensor* value_cache_out);
@@ -23,8 +23,7 @@ __global__ void VariableLengthRotaryKernel(
    const int *qkv,
    const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
    const float *sin_emb,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *padding_offsets,
    const int *seq_lens,
    const int *seq_lens_decoder,
    const float *qkv_out_scales, // [3, num_head, dim_head]
@@ -53,7 +52,8 @@ __global__ void VariableLengthRotaryKernel(
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / offset;
    const int ori_bi = batch_id_per_token[token_idx];
    const int ori_token_idx = token_idx + padding_offsets[token_idx];
    const int ori_bi = ori_token_idx / seq_len;
    if (seq_lens && seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % offset;
    const int qkv_id = bias / hidden_size;
@@ -61,7 +61,7 @@ __global__ void VariableLengthRotaryKernel(
    const int hi = qkv_bias / last_dim;
    const int h_bias = qkv_bias % last_dim;

    const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
    const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];

    const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
    const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias;
@@ -107,8 +107,7 @@ __global__ void VariableLengthRotaryKernel(
    const T *qkv,
    const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
    const float *sin_emb,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *padding_offsets,
    const int *seq_lens,
    const int *seq_lens_decoder,
    T *qkv_out,
@@ -131,7 +130,8 @@ __global__ void VariableLengthRotaryKernel(
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / offset;
    const int ori_bi = batch_id_per_token[token_idx];
    const int ori_token_idx = token_idx + padding_offsets[token_idx];
    const int ori_bi = ori_token_idx / seq_len;
    if (seq_lens && seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % offset;
    const int qkv_id = bias / hidden_size;
@@ -139,7 +139,7 @@ __global__ void VariableLengthRotaryKernel(
    const int hi = qkv_bias / last_dim;
    const int h_bias = qkv_bias % last_dim;

    const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
    const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];

    const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
    const int64_t base_idx = token_idx * 3 * hidden_size +
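The recurring change in the hunks above swaps the `padding_offsets`-based arithmetic for a direct lookup through `batch_id_per_token` and `cu_seqlens_q`. A minimal sketch of the two computations, assuming `batch_id_per_token` maps every packed token to its batch and `cu_seqlens_q` holds per-batch prefix sums of query lengths (the helper itself is illustrative and not part of the repository):

// Illustrative device helper: both paths recover the batch id and the absolute
// in-sequence position of a packed token (names taken from the diff above).
__device__ __forceinline__ void recover_token_position(
    int token_idx, int max_seq_len,
    const int *padding_offsets,     // inputs used by the old arithmetic
    const int *batch_id_per_token,  // inputs used by the new lookup
    const int *cu_seqlens_q,
    const int *seq_lens_decoder,
    int *ori_bi, int *ori_seq_id) {
  // Old scheme: reconstruct the padded index, then divide / modulo by max_seq_len.
  // const int ori_token_idx = token_idx + padding_offsets[token_idx];
  // *ori_bi     = ori_token_idx / max_seq_len;
  // *ori_seq_id = ori_token_idx % max_seq_len + seq_lens_decoder[*ori_bi];

  // New scheme: direct batch lookup plus an offset into that batch's query range.
  *ori_bi = batch_id_per_token[token_idx];
  *ori_seq_id = (token_idx - cu_seqlens_q[*ori_bi]) + seq_lens_decoder[*ori_bi];
}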
@@ -167,8 +167,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
    const int *qkv,
    const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
    const float *sin_emb,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *padding_offsets,
    const int *seq_lens,
    const int *seq_lens_decoder,
    const float *qkv_out_scales, // [3, num_head, dim_head]
@@ -200,7 +199,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / offset;
    const int ori_bi = batch_id_per_token[token_idx];
    const int ori_token_idx = token_idx + padding_offsets[token_idx];
    const int ori_bi = ori_token_idx / seq_len;
    if (seq_lens && seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % offset;
    const int qkv_id = bias / hidden_size;
@@ -208,7 +208,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
    const int hi = qkv_bias / half_lastdim;
    const int h_bias = qkv_bias % half_lastdim;

    const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
    const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];

    const int emb_idx = ori_seq_id * last_dim + h_bias;
    const int bias_idx_left =
@@ -261,8 +261,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
    const T *qkv,
    const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
    const float *sin_emb,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *padding_offsets,
    const int *seq_lens,
    const int *seq_lens_decoder,
    T *qkv_out,
@@ -286,7 +285,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / offset;
    const int ori_bi = batch_id_per_token[token_idx];
    const int ori_token_idx = token_idx + padding_offsets[token_idx];
    const int ori_bi = ori_token_idx / seq_len;
    if (seq_lens && seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % offset;
    const int qkv_id = bias / hidden_size;
@@ -294,7 +294,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
    const int hi = qkv_bias / half_lastdim;
    const int h_bias = qkv_bias % half_lastdim;

    const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
    const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];

    const int emb_idx = ori_seq_id * last_dim + h_bias;
    const int base_idx_left = token_idx * 3 * full_hidden_size +
@@ -327,8 +327,7 @@ __global__ void GQAVariableLengthRotaryKernel(
|
||||
const int *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const float *qkv_out_scales, // [3, q_num_head, dim_head]
|
||||
@@ -358,13 +357,14 @@ __global__ void GQAVariableLengthRotaryKernel(
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];;
|
||||
const int ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const int ori_bi = ori_token_idx / seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / last_dim;
|
||||
const int h_bias = bias % last_dim;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
const int64_t bias_idx = hi * last_dim + h_bias;
|
||||
@@ -410,8 +410,7 @@ __global__ void GQAVariableLengthRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
T *qkv_out,
|
||||
@@ -435,13 +434,14 @@ __global__ void GQAVariableLengthRotaryKernel(
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];;
|
||||
const int ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const int ori_bi = ori_token_idx / seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / last_dim;
|
||||
const int h_bias = bias % last_dim;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
const int64_t base_idx =
|
||||
@@ -472,8 +472,7 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
const float *sin_emb,
|
||||
const float *qkv_out_scales,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const T *qkv_biases,
|
||||
@@ -505,13 +504,15 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
const int ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const int ori_bi = ori_token_idx / seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / last_dim;
|
||||
const int h_bias = bias % last_dim;
|
||||
|
||||
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
int ori_seq_id;
|
||||
ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
const int64_t bias_idx = hi * last_dim + h_bias;
|
||||
@@ -560,8 +561,7 @@ template <typename T, int VecSize = 1>
|
||||
__global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const T *qkv_biases,
|
||||
@@ -590,13 +590,15 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
const int ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const int ori_bi = ori_token_idx / seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / last_dim;
|
||||
const int h_bias = bias % last_dim;
|
||||
|
||||
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
int ori_seq_id;
|
||||
ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
const int64_t bias_idx = hi * last_dim + h_bias;
|
||||
@@ -643,8 +645,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
const int *qkv,
|
||||
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const float *qkv_out_scales, // [3, q_num_head, dim_head]
|
||||
@@ -675,13 +676,14 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
const int ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const int ori_bi = ori_token_idx / seq_len;
|
||||
if (seq_lens && seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / half_lastdim;
|
||||
const int h_bias = bias % half_lastdim;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * last_dim + h_bias;
|
||||
const int bias_idx_left = hi * last_dim + h_bias;
|
||||
@@ -734,8 +736,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const float *qkv_out_scales,
|
||||
@@ -760,13 +761,14 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
const int ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const int ori_bi = ori_token_idx / seq_len;
|
||||
if (seq_lens && seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / half_lastdim;
|
||||
const int h_bias = bias % half_lastdim;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * last_dim + h_bias;
|
||||
const int base_idx_left =
|
||||
@@ -803,8 +805,7 @@ __global__ void cache_kernel(
    T *__restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
                                 // head_size]
    const int *__restrict__ block_tables,        // [bsz, max_blocks_per_seq]
    const int *__restrict__ batch_id_per_token,  // [num_tokens]
    const int *__restrict__ cu_seqlens_q,        // [bsz]
    const int *__restrict__ padding_offsets,     // [num_tokens]
    const int *__restrict__ seq_lens,            // [bsz]
    const int *__restrict__ seq_lens_decoder,    // [bsz]
    const int max_seq_len,
@@ -830,9 +831,11 @@ __global__ void cache_kernel(
    const uint32_t qkv_bias = bias % hidden_size;
    const uint32_t hi = qkv_bias / head_size;
    const uint32_t h_bias = qkv_bias % head_size;
    const uint32_t ori_bi = batch_id_per_token[token_idx];
    const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
    const uint32_t ori_bi = ori_token_idx / max_seq_len;
    if (seq_lens[ori_bi] == 0) continue;
    const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
    const uint32_t ori_seq_id =
        ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];

    const int32_t *block_table_now = nullptr;
@@ -875,8 +878,8 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
const int *__restrict__ tile_ids,
|
||||
const int *__restrict__ seq_lens_this_time,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
const int *__restrict__ batch_id_per_token,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ padding_offsets,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ block_tables,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
@@ -906,46 +909,15 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
const uint32_t end_len = start_len + seq_len_this_time;
|
||||
|
||||
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
||||
|
||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
||||
const uint32_t start_token_idx =
|
||||
batch_id * max_seq_len - cum_offsets[batch_id];
|
||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = HEAD_DIM;
|
||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
if (tile_start >= start_len) {
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
// int lane_id = wid * 32 + tid;
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
// reset k
|
||||
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
|
||||
uint32_t tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
|
||||
tid % num_vecs_per_head_k * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_k;
|
||||
block_i < BLOCK_SIZE;
|
||||
block_i += num_token_each_time_k) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&cache_k[tgt_idx + block_i * HEAD_DIM]);
|
||||
}
|
||||
|
||||
// reset v
|
||||
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
|
||||
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
|
||||
tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
|
||||
tid % num_vecs_per_head_v * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
|
||||
block_i += num_token_each_time_v) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
|
||||
}
|
||||
}
|
||||
smem_t k_smem(k_smem_ori);
|
||||
smem_t v_smem(v_smem_ori);
|
||||
|
||||
@@ -1008,6 +980,7 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
|
||||
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
|
||||
uint32_t kv_frag[4];
|
||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
||||
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_b_stride = HEAD_DIM;
|
||||
@@ -1145,8 +1118,8 @@ __global__ void append_write_cache_kv_c4_qkv(
|
||||
const int *__restrict__ tile_ids,
|
||||
const int *__restrict__ seq_lens_this_time,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
const int *__restrict__ batch_id_per_token,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ padding_offsets,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ block_tables,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
@@ -1175,46 +1148,10 @@ __global__ void append_write_cache_kv_c4_qkv(
|
||||
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
||||
|
||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
||||
const uint32_t start_token_idx =
|
||||
batch_id * max_seq_len - cum_offsets[batch_id];
|
||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = HEAD_DIM;
|
||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
||||
|
||||
const uint32_t HEAD_DIM_HALF = HEAD_DIM / 2;
|
||||
const uint32_t BLOCK_SIZE_HALF = BLOCK_SIZE / 2;
|
||||
|
||||
if (tile_start >= start_len) {
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
// reset k
|
||||
constexpr int num_vecs_per_head_k = HEAD_DIM_HALF / KV_VEC_SIZE; // 4
|
||||
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; // 8
|
||||
uint32_t tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM_HALF +
|
||||
tid % num_vecs_per_head_k * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_k;
|
||||
block_i < BLOCK_SIZE;
|
||||
block_i += num_token_each_time_k) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&cache_k[tgt_idx + block_i * HEAD_DIM_HALF]);
|
||||
}
|
||||
|
||||
// reset v
|
||||
const int num_vecs_per_head_v = BLOCK_SIZE_HALF / KV_VEC_SIZE; // 2
|
||||
const int num_token_each_time_v = 32 / num_vecs_per_head_v; // 16
|
||||
tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE_HALF +
|
||||
tid % num_vecs_per_head_v * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
|
||||
block_i += num_token_each_time_v) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE_HALF]);
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T k_scale_smem[HEAD_DIM];
|
||||
@@ -1325,6 +1262,7 @@ __global__ void append_write_cache_kv_c4_qkv(
|
||||
|
||||
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
|
||||
uint32_t kv_frag[4];
|
||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
||||
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2;
|
||||
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
|
||||
const uint32_t write_b_stride = HEAD_DIM / 2;
|
||||
@@ -1469,8 +1407,7 @@ void rotary_qk_variable(
|
||||
const float *qkv_out_scales, // [3, num_head, dim_head]
|
||||
const T *qkv_bias,
|
||||
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const int token_num,
|
||||
@@ -1502,8 +1439,7 @@ void rotary_qk_variable(
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
@@ -1519,8 +1455,7 @@ void rotary_qk_variable(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
@@ -1538,8 +1473,7 @@ void rotary_qk_variable(
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
@@ -1555,8 +1489,7 @@ void rotary_qk_variable(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
@@ -1575,8 +1508,7 @@ void gqa_rotary_qk_variable(
|
||||
const float *qkv_out_scales, // [3, num_head, dim_head]
|
||||
const T *qkv_bias,
|
||||
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const int token_num,
|
||||
@@ -1611,8 +1543,7 @@ void gqa_rotary_qk_variable(
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
@@ -1630,8 +1561,7 @@ void gqa_rotary_qk_variable(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out,
|
||||
@@ -1651,8 +1581,7 @@ void gqa_rotary_qk_variable(
|
||||
reinterpret_cast<const int *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
@@ -1669,8 +1598,7 @@ void gqa_rotary_qk_variable(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
@@ -1694,8 +1622,7 @@ void gqa_rotary_qk_quant_variable(
|
||||
const T *cache_k_scales,
|
||||
const T *cache_v_scales,
|
||||
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *padding_offsets,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const int token_num,
|
||||
@@ -1727,8 +1654,7 @@ void gqa_rotary_qk_quant_variable(
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_bias,
|
||||
@@ -1747,8 +1673,7 @@ void gqa_rotary_qk_quant_variable(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_bias,
|
||||
@@ -1774,8 +1699,7 @@ void CascadeAppendWriteCacheKVQKV(
|
||||
&qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const int max_seq_len,
|
||||
@@ -1801,8 +1725,7 @@ void CascadeAppendWriteCacheKVQKV(
|
||||
reinterpret_cast<T *>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<T *>(value_cache_out->data<T>()),
|
||||
block_table.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
max_seq_len,
|
||||
@@ -1826,8 +1749,8 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
const paddle::Tensor &cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -1891,8 +1814,8 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
@@ -1914,8 +1837,8 @@ void CascadeAppendWriteCacheKVC4QKV(
|
||||
const paddle::Tensor &cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &batch_id_per_token,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -1961,8 +1884,8 @@ void CascadeAppendWriteCacheKVC4QKV(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
|
@@ -25,8 +25,8 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids,
|
||||
@@ -63,8 +63,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
@@ -83,8 +82,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
@@ -105,8 +103,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
|
||||
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
@@ -126,8 +123,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
CascadeAppendWriteCacheKVQKV<T>(meta_data,
|
||||
*qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
max_seq_len,
|
||||
@@ -146,8 +142,8 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
cache_v_scale.get(),
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
batch_ids,
|
||||
tile_ids,
|
||||
@@ -173,8 +169,8 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
cache_v_zp.get(),
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
block_tables,
|
||||
batch_ids,
|
||||
tile_ids,
|
||||
|
@@ -194,12 +194,12 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets,
    const int encoder_block_shape_q, const int decoder_block_shape_q,
    const int group_size, const int block_size,
    const int decoder_step_token_num) {
  auto stream = seq_lens_encoder.stream();
  int bsz = seq_lens_this_time.shape()[0];
  int bsz = cum_offsets.shape()[0];
  auto max_len_tensor =
      GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
  GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
@@ -335,7 +335,8 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
    const paddle::DataType &seq_lens_encoder_dtype,
    const paddle::DataType &seq_lens_decoder_dtype,
    const paddle::DataType &seq_lens_this_time_dtype) {
    const paddle::DataType &seq_lens_this_time_dtype,
    const paddle::DataType &cum_offsets_dtype) {
  return {
      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
@@ -346,7 +347,8 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
    const std::vector<int64_t> &seq_lens_encoder_shape,
    const std::vector<int64_t> &seq_lens_decoder_shape,
    const std::vector<int64_t> &seq_lens_this_time_shape) {
    const std::vector<int64_t> &seq_lens_this_time_shape,
    const std::vector<int64_t> &cum_offsets_shape) {
  std::vector<int64_t> dynamic_shape = {-1};

  return {dynamic_shape,
@@ -363,7 +365,8 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time",
             "cum_offsets"})
    .Outputs({paddle::Optional("encoder_batch_ids"),
              paddle::Optional("encoder_tile_ids_per_batch"),
              paddle::Optional("encoder_num_blocks"),
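Note on the hunk above: the two signature variants differ only in whether `cum_offsets` is threaded through and, consequently, in how the batch size is derived. Both derivations appear verbatim in the diff:

// Batch size without cum_offsets:
int bsz = seq_lens_this_time.shape()[0];
// Batch size with cum_offsets:
int bsz = cum_offsets.shape()[0];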
@@ -16,6 +16,7 @@
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "encoder_write_cache_with_rope_impl.cuh"
#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h"
#include "paddle/phi/backends/context_pool.h"
#include "remote_cache_kv_ipc.h"

@@ -24,8 +25,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
    const T *qkv,
    const float *cos_emb,
    const float *sin_emb,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *padding_offsets,
    const int *seq_lens,
    const int *seq_lens_decoder,
    const int *cu_seqlens_k,
@@ -52,13 +52,14 @@ __global__ void GQAVariableLengthRotarySplitKernel(
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / offset;
    const int ori_bi = batch_id_per_token[token_idx];
    const int ori_token_idx = token_idx + padding_offsets[token_idx];
    const int ori_bi = ori_token_idx / seq_len;
    if (seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % offset;
    const int hi = bias / last_dim;
    const int h_bias = bias % last_dim;

    const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
    const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
    const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;

    const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
@@ -107,10 +108,9 @@ void gqa_rotary_qk_split_variable(
    T *v,
    const T *qkv_input,
    const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
    const int *batch_id_per_token,
    const int *padding_offsets,
    const int *seq_lens_encoder,
    const int *seq_lens_decoder,
    const int *cu_seqlens_q,
    const int *cu_seqlens_k,
    const int token_num,
    const int num_heads,
@@ -133,8 +133,7 @@ void gqa_rotary_qk_split_variable(
        qkv_input,
        cos_emb,
        sin_emb,
        batch_id_per_token,
        cu_seqlens_q,
        padding_offsets,
        seq_lens_encoder,
        seq_lens_decoder,
        cu_seqlens_k,
@@ -149,188 +148,13 @@ void gqa_rotary_qk_split_variable(
        dim_head);
}

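The split variant above additionally computes `kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id` so that rotated K/V rows can be scattered into a packed per-token K/V buffer. A hedged addressing sketch, assuming a `[total_kv_tokens, kv_num_heads, dim_head]` output layout (the layout and strides are inferences for illustration, not taken from the kernel body shown here):

// Hypothetical packed-output addressing for GQAVariableLengthRotarySplitKernel.
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;        // row in the packed K/V buffer
const int64_t k_elem_offset =
    static_cast<int64_t>(kv_write_idx) * kv_num_heads * dim_head   // token stride (assumed)
    + kv_head_idx * dim_head                                       // head stride (assumed)
    + h_bias;                                                      // element within the head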
template <typename T,
          typename CacheT,
          uint32_t HEAD_DIM,
          uint32_t BLOCK_SIZE,
          uint32_t NUM_WARPS = 4>
__global__ void append_cache_kv_c16(
    const T *__restrict__ cache_k,
    const T *__restrict__ cache_v,
    T *__restrict__ k_out,
    T *__restrict__ v_out,
    const int *__restrict__ seq_lens_this_time,
    const int *__restrict__ seq_lens_decoder,
    const int *__restrict__ cu_seqlens_k,
    const int *__restrict__ block_tables,
    const int *batch_ids,
    const int *tile_ids_per_batch,
    const int max_blocks_per_seq,
    const int kv_num_heads) {
  // start_kv_idx: starting kv_idx of the current block
  // batch_id: the batch this block belongs to
  // TODO: 1. scale preload 2. frag_dq_T reuse 3. pipeline 4. store aligned 5. cacheT with template (int8/fp8)
  const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z;
  const uint32_t tid = threadIdx.x, wid = threadIdx.y;

  const uint32_t batch_id = batch_ids[tile_idx];
  const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
  const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
  if (seq_lens_this_time[batch_id] <= 0) {
    return;
  }

  const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq;
  uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE];
  // cache_kv idx
  uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
  uint32_t block_stride = kv_num_heads * kv_h_stride;
  const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride;
  const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride;

  // k_out v_out idx
  uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
  T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
  T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;

  uint32_t kv_frag[4];
  T *frag_dq_T = reinterpret_cast<T *>(kv_frag);

  constexpr uint32_t num_vecs_per_head =
      HEAD_DIM / num_elems_per_128b<CacheT>();
  constexpr uint32_t inv_kv_stride = 8 / num_vecs_per_head;

  extern __shared__ uint8_t smem[];
  smem_t k_smem(smem);
  uint32_t k_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
      wid * 4 + tid / 8, tid % 8);  // 4 * 4 per warp

  uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
      wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

  uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
                        tid % 8 * num_elems_per_128b<CacheT>();

  // load k_smem: 64 rows, 128 cols
  for (int fz = 0; fz < 4; fz++) {  // 4 rows per warp at once, 16 rows across all 4 warps, need 4 iters
    for (int fy = 0; fy < 2; fy++) {  // 8 * 128b = 64 * bf16 at once, need 2 iters
      k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
          k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
      k_smem_offset_w =
          k_smem.advance_offset_by_column<8, num_vecs_per_head>(k_smem_offset_w, fy);
      k_read_idx += 8 * num_elems_per_128b<CacheT>();
    }
    k_smem_offset_w =
        k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 16;
    k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b<CacheT>();
  }
  commit_group();
  wait_group<0>();
  __syncthreads();

  // process k_smem: 64 rows, 128 cols
  for (int fz = 0; fz < 1; fz++) {  // 16 rows per warp at once, 64 rows across all 4 warps, need 1 iter
    uint32_t row_idx = wid * 16 + tid / 4;
    for (int fy = 0; fy < 8; fy++) {  // 2 * 128b = 16 * bf16 at once, need 8 iters
      uint32_t col_idx = fy * 16 + tid % 4 * 2;
      k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
      // layout
      /***
        r0c0,r0c1, r0c8,r0c9
        r8c0,r8c1, r8c8,r8c9
      ***/
      T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
      T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;

      if (row_idx < end_idx) {
        k_tile_ptr0[0] = frag_dq_T[0];
        k_tile_ptr0[1] = frag_dq_T[1];
        k_tile_ptr0[8] = frag_dq_T[2];
        k_tile_ptr0[9] = frag_dq_T[3];
      }

      if (row_idx + 8 < end_idx) {
        k_tile_ptr1[0] = frag_dq_T[4];
        k_tile_ptr1[1] = frag_dq_T[5];
        k_tile_ptr1[8] = frag_dq_T[6];
        k_tile_ptr1[9] = frag_dq_T[7];
      }
      k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>(
          k_smem_offset_r, fy);
    }
    k_smem_offset_r =
        k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16;
  }

  // ================v================
  smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT));
  uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
      wid * 4 + tid / 8, tid % 8);  // 4 * 4 per warp
  uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
      wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

  uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
                        tid % 8 * num_elems_per_128b<CacheT>();

  // load v_smem: 64 rows, 128 cols
  for (int fz = 0; fz < 4; fz++) {  // 4 rows per warp at once, 16 rows across all 4 warps, need 4 iters
    for (int fy = 0; fy < 2; fy++) {  // 8 * 128b = 64 * bf16 at once, need 2 iters
      v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
          v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
      v_smem_offset_w =
          v_smem.advance_offset_by_column<8, num_vecs_per_head>(v_smem_offset_w, fy);
      v_read_idx += 8 * num_elems_per_128b<CacheT>();
    }
    v_smem_offset_w =
        v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 16;
    v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b<CacheT>();
  }
  commit_group();
  wait_group<0>();
  __syncthreads();

  // process v_smem: 64 rows, 128 cols
  for (int fz = 0; fz < 1; fz++) {  // 16 rows per warp at once, 64 rows across all 4 warps, need 1 iter
    uint32_t row_idx = wid * 16 + tid / 4;
    for (int fy = 0; fy < 8; fy++) {  // 2 * 128b = 16 * bf16 at once, need 8 iters
      uint32_t col_idx = fy * 16 + tid % 4 * 2;
      v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
      // layout
      /***
        r0c0,r0c1, r0c8,r0c9
        r8c0,r8c1, r8c8,r8c9
      ***/
      T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
      T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride;

      if (row_idx < end_idx) {
        v_tile_ptr0[0] = frag_dq_T[0];
        v_tile_ptr0[1] = frag_dq_T[1];
        v_tile_ptr0[8] = frag_dq_T[2];
        v_tile_ptr0[9] = frag_dq_T[3];
      }

      if (row_idx + 8 < end_idx) {
        v_tile_ptr1[0] = frag_dq_T[4];
        v_tile_ptr1[1] = frag_dq_T[5];
        v_tile_ptr1[8] = frag_dq_T[6];
        v_tile_ptr1[9] = frag_dq_T[7];
      }
      v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>(
          v_smem_offset_r, fy);
    }
    v_smem_offset_r =
        v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_r) - 16;
  }
}

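A hedged launch sketch for `append_cache_kv_c16`. The grid/block layout is inferred from the indexing above (`blockIdx.x` iterates tiles, `blockIdx.z` iterates KV heads, four warps of 32 threads); the concrete template values, `num_tiles`, and the shared-memory size are illustrative assumptions rather than the repository's actual launch site:

// Illustrative only: the values below are assumptions, not the real launch configuration.
constexpr uint32_t HEAD_DIM = 128, BLOCK_SIZE = 64, NUM_WARPS = 4;
dim3 grid(num_tiles, 1, kv_num_heads);
dim3 block(32, NUM_WARPS);
// One K tile plus one V tile of BLOCK_SIZE x HEAD_DIM elements each.
size_t smem_bytes = 2 * BLOCK_SIZE * HEAD_DIM * sizeof(T);
append_cache_kv_c16<T, T, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>
    <<<grid, block, smem_bytes, stream>>>(
        cache_k, cache_v, k_out, v_out, seq_lens_this_time, seq_lens_decoder,
        cu_seqlens_k, block_tables, batch_ids, tile_ids_per_batch,
        max_blocks_per_seq, kv_num_heads);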
template <typename T,
          typename CacheT,
          uint32_t HEAD_DIM,
          uint32_t BLOCK_SIZE,
          uint32_t NUM_WARPS = 4,
          bool IS_FP8 = false>
__global__ void append_cache_kv_c8(
__global__ void append_dequant_cache_kv_c8(
    const CacheT *__restrict__ cache_k,
    const CacheT *__restrict__ cache_v,
    T *__restrict__ k_out,
@@ -345,16 +169,16 @@ __global__ void append_cache_kv_c8(
    const int *tile_ids_per_batch,
    const int max_blocks_per_seq,
    const int kv_num_heads) {
  // start_kv_idx: start kv_idx current block
  // batch_id: block's batch_id
  // TODO: 1. scale preload 2. frag_dq_T reuse 3. pipeline 4. store aligned 5. cacheT with template (int8/fp8)
  // start_kv_idx: starting kv_idx of each block
  // batch_id: the batch each block belongs to
  // TODO: 1. scale prefetch 2. frag_dq_T reuse 3. pipelining 4. coalesced stores 5. cacheT support (int8/fp8)
  const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z;
  const uint32_t tid = threadIdx.x, wid = threadIdx.y;

  const uint32_t batch_id = batch_ids[tile_idx];
  const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
  const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
  if (seq_lens_this_time[batch_id] <= 0) {
  if (seq_lens_this_time <= 0) {
    return;
  }

@@ -368,8 +192,8 @@ __global__ void append_cache_kv_c8(

  // k_out v_out idx
  uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
  T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
  T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
  T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;  // start pointer of the current k block
  T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;  // start pointer of the current v block

  uint32_t k_frag[4], v_frag[4], frag_dq[4];
  T *frag_dq_T = reinterpret_cast<T *>(frag_dq);
@@ -390,13 +214,13 @@ __global__ void append_cache_kv_c8(

  uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
      wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

  uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
                        tid % 8 * num_elems_per_128b<CacheT>();

  // load v_smem 64 rows, 128 cols
  for (int fz = 0; fz < 4; fz++) {  // 4 rows per warp at once, 16 rows across all 4 warps, need 4 iters
    for (int fy = 0; fy < 1; fy++) {  // 8 * 128b = 128 * uint8 at once, need 1 iter
  // load k_smem: 64 rows, 128 cols
  for (int fz = 0; fz < 4; fz++) {  // 4 rows per warp per pass, 16 rows per pass across 4 warps, 64 rows total
    for (int fy = 0; fy < 1; fy++) {  // 8 x 128b = 128 uint8 per pass
      k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
          k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
      k_smem_offset_w =
@@ -411,13 +235,13 @@ __global__ void append_cache_kv_c8(
  wait_group<0>();
  __syncthreads();

  // process k_smem 64 rows, 128 cols
  for (int fz = 0; fz < 1; fz++) {  // 16 rows per warp at once, 64 rows across all 4 warps, need 1 iter
  // process k_smem: 64 rows, 128 cols
  for (int fz = 0; fz < 1; fz++) {  // 16 rows per warp per pass, 64 rows across 4 warps
    uint32_t row_idx = wid * 16 + tid / 4;
    for (int fy = 0; fy < 4; fy++) {  // 2 * 128b = 32 * uint8 at once, need 4 iters
    for (int fy = 0; fy < 4; fy++) {  // 2 x 128b (32 uint8) per pass, 8 x 128b (128 uint8) over 4 passes
      uint32_t col_idx = fy * 32 + tid % 4 * 2;
      k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
      // layout
      // dequantize and store
      /***
        r0c0,r0c1,r0c8,r0c9, r8c0,r8c1,r8c8,r8c9
        r0c16,r0c17,r0c24,r0c25, r8c16,r8c17,r8c24,r8c25
@@ -427,7 +251,8 @@ __global__ void append_cache_kv_c8(
      T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;

      if (row_idx < end_idx) {
        convert_c8<T,IS_FP8>(frag_dq_T, k_frag[2 * i]);  // 4 * uint8/fp8 -> 4 * T
        convert_c8<T,IS_FP8>(frag_dq_T, k_frag[2 * i]);  // 4 uint8/fp8 -> 4 T

        k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale;
        k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale;
        k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale;
@@ -435,7 +260,8 @@ __global__ void append_cache_kv_c8(
      }

      if (row_idx + 8 < end_idx) {
        convert_c8<T,IS_FP8>(frag_dq_T + 4, k_frag[2 * i + 1]);  // 4 * uint8/fp8 -> 4 * T
        convert_c8<T,IS_FP8>(frag_dq_T + 4, k_frag[2 * i + 1]);  // 4 uint8/fp8 -> 4 T

        k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale;
        k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale;
        k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale;
@@ -449,8 +275,8 @@ __global__ void append_cache_kv_c8(
    k_smem_offset_r =
        k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 8;
  }

  // ================v================

  smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT));
  uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
      wid * 8 + tid / 4, tid % 4);  // 4 * 8 per warp
@@ -460,9 +286,9 @@ __global__ void append_cache_kv_c8(

  uint32_t v_read_idx = (wid * 8 + tid / 4) * BLOCK_SIZE +
                        tid % 4 * num_elems_per_128b<CacheT>();
  // load v_smem 128 rows 64 cols
  for (int fy = 0; fy < 4; fy++) {  // 8 rows per warp at once, 32 rows across all 4 warps, need 4 iters
    for (int fz = 0; fz < 1; fz++) {  // 4 * 128b = 64 * uint8 at once, need 1 iter
  // load v_smem: 128 rows, 64 cols
  for (int fy = 0; fy < 4; fy++) {  // 8 rows per warp per pass, 32 rows per pass across 4 warps, 128 rows total
    for (int fz = 0; fz < 1; fz++) {  // 4 x 128b = 64 uint8 per pass
      v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
          v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
      v_smem_offset_w =
@@ -478,32 +304,42 @@ __global__ void append_cache_kv_c8(
  wait_group<0>();
  __syncthreads();

  // process v_smem 128 rows 64 cols
  for (int fy = 0; fy < 2; fy++) {  // 16 rows per warp at once, 64 rows across all 4 warps, need 2 iters
  // process v_smem: 128 rows, 64 cols; row_idx indexes head_dim, col_idx indexes block_size
  for (int fy = 0; fy < 2; fy++) {  // 16 rows per warp per pass, 32 rows over 2 passes, 128 rows across 4 warps
    uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
    for (int fz = 0; fz < 2; fz++) {  // 2 * 128b = 32 * uint8 at once, need 2 iters
    for (int fz = 0; fz < 2; fz++) {  // 2 x 128b (32 uint8) per pass, 4 x 128b (64 uint8) over 2 passes
      uint32_t kv_idx = fz * 32 + tid % 4 * 2;
      v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
      // layout
      // dequantize and store
      for (int i = 0; i < 4 / 2; i++) {
        T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx;
        T *v_tile_ptr1 = v_tile_ptr0 + 8;
        convert_c8<T,IS_FP8>(frag_dq_T, v_frag[2 * i]);          // 4 * uint8/fp8 -> 4 * T
        convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]);  // 4 * uint8/fp8 -> 4 * T
        if (kv_idx < end_idx) {
          convert_c8<T,IS_FP8>(frag_dq_T, v_frag[2 * i]);  // 4 uint8/fp8 -> 4 T
#ifdef C8_DEBUG
          if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
            printf("1.fy: %d, fz:%d, row_idx: %d, col_idx: %d, v_frag: %.f, %.f, %.f, %.f \n",
                   fy, fz, kv_idx, dim_idx, static_cast<float>(frag_dq_T[0]), static_cast<float>(frag_dq_T[1]),
                   static_cast<float>(frag_dq_T[2]), static_cast<float>(frag_dq_T[3]));
          }
#endif
          v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale;
          v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale;
        }
        if (kv_idx + 1 < end_idx) {
          v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale;
          v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale;
        }
        if (kv_idx + 8 < end_idx) {
          v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
          v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale;
        }
        if (kv_idx + 9 < end_idx) {
          v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;


          convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]);  // 4 uint8/fp8 -> 4 T
#ifdef C8_DEBUG
          if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
            printf("2.fy: %d, fz:%d, row_idx: %d, col_idx: %d, v_frag: %.f, %.f, %.f, %.f \n",
                   fy, fz, kv_idx, dim_idx + 8, static_cast<float>(frag_dq_T[4]), static_cast<float>(frag_dq_T[5]),
                   static_cast<float>(frag_dq_T[6]), static_cast<float>(frag_dq_T[7]));
          }
#endif
          v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale;
          v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale;
          v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale;
          v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale;
        }
        kv_idx += 16;
@@ -516,250 +352,12 @@ __global__ void append_cache_kv_c8(
}
}

template <typename T,
typename CacheT,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS=4>
__global__ void append_cache_kv_c4(
const CacheT *__restrict__ cache_k,
const CacheT *__restrict__ cache_v,
T *__restrict__ k_out,
T *__restrict__ v_out,
const T *__restrict__ cache_k_dequant_scales,
const T *__restrict__ cache_v_dequant_scales,
const T *__restrict__ cache_k_zero_point,
const T *__restrict__ cache_v_zero_point,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ cu_seqlens_k,
const int *__restrict__ block_tables,
const int *batch_ids,
const int *tile_ids_per_batch,
const int max_blocks_per_seq,
const int kv_num_heads) {
// start_kv_idx: starting kv_idx of the current block
// batch_id: the block's batch_id
// TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8)
const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t tid = threadIdx.x, wid = threadIdx.y;

const uint32_t batch_id = batch_ids[tile_idx];
const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
if (seq_lens_this_time[batch_id] <= 0) {
return;
}

const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq;
uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE];
if (block_id < 0) block_id = 0;

constexpr uint32_t HEAD_DIM_HALF = HEAD_DIM / 2;
constexpr uint32_t BLOCK_SIZE_HALF = BLOCK_SIZE / 2;
// cache_kv idx
uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM_HALF;
uint32_t block_stride = kv_num_heads * kv_h_stride;
const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride;
const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride;

// k_out v_out idx
uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;
T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;

extern __shared__ uint8_t smem[];

uint32_t k_frag[4], v_frag[4], frag_dq[8];
T *frag_dq_T = reinterpret_cast<T *>(frag_dq);

// load dequant scales and zero points
const T *cache_k_scale_now = cache_k_dequant_scales + kv_head_idx * HEAD_DIM;
const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_now = cache_v_dequant_scales + kv_head_idx * HEAD_DIM;
const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM;
T *cache_k_scale_smem = reinterpret_cast<T *>(
smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT));
T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM;
T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM;
T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM;
#pragma unroll
for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) {
cache_k_scale_smem[i] = cache_k_scale_now[i];
cache_k_zero_point_smem[i] = cache_k_zp_now[i] - static_cast<T>(136.f);
cache_v_scale_smem[i] = cache_v_scale_now[i];
cache_v_zero_point_smem[i] = cache_v_zp_now[i] - static_cast<T>(136.f);
}

smem_t k_smem(smem);
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM_HALF / num_elems_per_128b<CacheT>(); // 2
constexpr uint32_t num_vecs_per_blocksize =
BLOCK_SIZE_HALF / num_elems_per_128b<CacheT>();
constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; // 4
constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize;

uint32_t k_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 8 + tid / 4, tid % 4); // 2(iter) * 4(warp) * 8 row per warp

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); //

uint32_t k_read_idx = (wid * 8 + tid / 4) * HEAD_DIM / 2 +
tid % 4 * num_elems_per_128b<CacheT>();

// load k_smem 64 rows 128 cols
for (int fz = 0; fz < 2; fz++) { // 4 rows per warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
k_smem.advance_offset_by_column<4, num_vecs_per_head_k>(k_smem_offset_w, fy);
k_read_idx += 4 * num_elems_per_128b<CacheT>();
}
k_smem_offset_w =
k_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_w) - 4;
k_read_idx += 8 * NUM_WARPS * HEAD_DIM / 2 - 4 * num_elems_per_128b<CacheT>();
}
commit_group();
wait_group<0>();
__syncthreads();

// deal k_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows per warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
uint32_t col_idx = fy * 64 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);

for (int i = 0; i < 2; i++) {
T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;
convert_int4(frag_dq_T, k_frag[2 * i]);
convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]);

if (row_idx < end_idx) {
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
k_tile_ptr0[16] = frag_dq_T[8] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
k_tile_ptr0[17] = frag_dq_T[9] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
k_tile_ptr0[24] = frag_dq_T[10] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
k_tile_ptr0[25] = frag_dq_T[11] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
}

if (row_idx + 8 < end_idx) {
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
k_tile_ptr1[16] = frag_dq_T[12] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
k_tile_ptr1[17] = frag_dq_T[13] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
k_tile_ptr1[24] = frag_dq_T[14] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
k_tile_ptr1[25] = frag_dq_T[15] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
}
col_idx += 32;
}
k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head_k>(
k_smem_offset_r, fy);
}
k_smem_offset_r =
k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 4;
}

// ================v================
smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT) / 2);
uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
wid * 16 + tid / 2, tid % 2); // 4 * 8 per warp

uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

uint32_t v_read_idx = (wid * 16 + tid / 2) * BLOCK_SIZE_HALF +
tid % 2 * num_elems_per_128b<CacheT>();
// load v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows per warp once, 64 rows all 4 warps once, need 2 iter
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>(v_smem_offset_w, fz);
v_read_idx += 2 * num_elems_per_128b<CacheT>();
}
v_smem_offset_w =
v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_w) - 2;
v_read_idx += 16 * NUM_WARPS * BLOCK_SIZE_HALF - 2 * num_elems_per_128b<CacheT>();
}

commit_group();
wait_group<0>();
__syncthreads();

// deal v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows per warp once, 64 rows all 4 warps once, need 2 iter
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
// layout
for (int i = 0; i < 2; i++) {
T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx;
T *v_tile_ptr1 = v_tile_ptr0 + 8;

convert_int4(frag_dq_T, v_frag[2 * i]);
convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]);
if (kv_idx < end_idx) {
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 1 < end_idx) {
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 8 < end_idx) {
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 9 < end_idx) {
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 16 < end_idx) {
v_tile_ptr0[16 * kv_t_stride] = frag_dq_T[8] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[16 * kv_t_stride] = frag_dq_T[12] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 17 < end_idx) {
v_tile_ptr0[17 * kv_t_stride] = frag_dq_T[9] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[17 * kv_t_stride] = frag_dq_T[13] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 24 < end_idx) {
v_tile_ptr0[24 * kv_t_stride] = frag_dq_T[10] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[24 * kv_t_stride] = frag_dq_T[14] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 25 < end_idx) {
v_tile_ptr0[25 * kv_t_stride] = frag_dq_T[11] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[25 * kv_t_stride] = frag_dq_T[15] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
kv_idx += 32;
}
v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>(
v_smem_offset_r, fz);
}
v_smem_offset_r =
v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_r) - 2;
}
}

template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
void AppendCacheKV(
void AppendDequantCache(
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::Tensor &cache_k_dequant_scales,
const paddle::Tensor &cache_v_dequant_scales,
const paddle::Tensor &cache_k_zp,
const paddle::Tensor &cache_v_zp,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &cu_seqlens_k,
@@ -773,41 +371,19 @@ void AppendCacheKV(
paddle::Tensor *k_out,
paddle::Tensor *v_out,
const cudaStream_t& stream
) {
) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
constexpr int NUM_WARPS = 4;
int block_num = cache_num_blocks_x.data<int>()[0];
dim3 grids(block_num, 1, kv_num_heads);
dim3 blocks(32, NUM_WARPS);
if (cache_quant_type == "none") {
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2;
auto kernel_func = append_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;

if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel_func,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
kernel_func<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_k.data<int>(),
block_tables.data<int>(),
cache_batch_ids.data<int>(),
cache_tile_ids_per_batch.data<int>(),
max_blocks_per_seq,
kv_num_heads
);
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
constexpr int NUM_WARPS = 4;
int block_num = cache_num_blocks_x.data<int>()[0];
dim3 grids(block_num, 1, kv_num_heads);
dim3 blocks(32, NUM_WARPS);

const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;

auto kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
if (cache_quant_type == "cache_fp8") {
kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel_func,
@@ -830,34 +406,6 @@ void AppendCacheKV(
max_blocks_per_seq,
kv_num_heads
);
} else if (cache_quant_type == "cache_int4_zp") {
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) + 4 * HEAD_DIM * sizeof(T);

auto kernel_func = append_cache_kv_c4<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;

if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel_func,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
kernel_func<<<grids, blocks, smem_size, stream>>>(
cache_k.data<uint8_t>(),
cache_v.data<uint8_t>(),
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_dequant_scales.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_dequant_scales.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_zp.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_zp.data<T>())),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_k.data<int>(),
block_tables.data<int>(),
cache_batch_ids.data<int>(),
cache_tile_ids_per_batch.data<int>(),
max_blocks_per_seq,
kv_num_heads
);
} else {
PADDLE_THROW("%s mode isn't implemented yet", cache_quant_type.c_str());
}
@@ -873,7 +421,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids,
@@ -901,9 +450,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const int token_num = qkv_dims[0];
const int max_blocks_per_seq = block_tables.dims()[1];
const int block_size = key_cache.dims()[2];
const int batch_size = seq_lens_this_time.dims()[0];
const int batch_size = cum_offsets.dims()[0];
const int kv_num_heads = key_cache_dims[1];
const int head_dim = cache_quant_type == "cache_int4_zp" ? key_cache_dims[3] * 2 : key_cache_dims[3];
const int head_dim = key_cache_dims[3];
const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
const float softmax_scale = 1.f / sqrt(head_dim);

@@ -914,7 +463,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
meta_data.q_num_heads = num_heads;
meta_data.max_blocks_per_seq = max_blocks_per_seq;
meta_data.block_size = block_size;
meta_data.batch_size = seq_lens_this_time.dims()[0];
meta_data.batch_size = cum_offsets.dims()[0];

phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(qkv.place()));

@@ -944,10 +493,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
batch_id_per_token.data<int>(),
padding_offsets.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
@@ -962,8 +510,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
meta_data,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
seq_lens_encoder,
seq_lens_decoder,
max_seq_len,
@@ -980,8 +527,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
cache_v_quant_scales.get(),
seq_lens_this_time,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
kv_batch_ids,
kv_tile_ids,
@@ -992,32 +539,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
} else if (cache_quant_type == "cache_int4_zp") {
CascadeAppendWriteCacheKVC4QKV<data_t, 128, 64>(
meta_data,
*const_cast<paddle::Tensor*>(&key_cache),
*const_cast<paddle::Tensor*>(&value_cache),
qkv_out,
cache_k_quant_scales.get(),
cache_v_quant_scales.get(),
cache_k_zp.get(),
cache_v_zp.get(),
seq_lens_this_time,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids,
kv_num_blocks_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8, "
"cache_int4_zp]");
}
const char* fmt_write_cache_completed_signal_str = std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk = std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
@@ -1040,13 +561,11 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
}

if (token_num < kv_token_num) {
AppendCacheKV<data_t, 128, 64>(
AppendDequantCache<data_t, 128, 64>(
key_cache,
value_cache,
cache_k_dequant_scales.get(),
cache_v_dequant_scales.get(),
cache_k_zp.get(),
cache_v_zp.get(),
seq_lens_this_time,
seq_lens_decoder,
cu_seqlens_k,
@@ -1075,7 +594,8 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"batch_id_per_token",
"padding_offsets",
"cum_offsets",
"block_tables",
"kv_batch_ids",
"kv_tile_ids_per_batch",
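Note (added for clarity, not part of the diff): the c8 and c4 append/dequant kernels above differ only in how a widened cache value is mapped back to T. Below is a minimal scalar sketch of that mapping, assuming convert_c8 / convert_int4 have already expanded the packed bytes or nibbles into T values; the helper names dequant_c8_ref / dequant_c4_ref are illustrative only.

template <typename T>
T dequant_c8_ref(T widened, T scale) {
  // cache_int8 / cache_fp8: per-head scale, no zero point
  return widened * scale;
}

template <typename T>
T dequant_c4_ref(T widened, T scale, T zero_point) {
  // cache_int4_zp: per-channel scale plus a shifted zero point
  // (the kernel preloads zero_point - 136.f into shared memory)
  return widened * scale + (zero_point - static_cast<T>(136.f));
}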
@@ -1,292 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "helper.h"
#include "mla_cache_kernel.cuh"

template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto num_tokens = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;

constexpr int PackSize = 16 / sizeof(DataType_);
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);

prefill_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
return {};
}

std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;

meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = seq_lens_decoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}

template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto bsz = meta_data.batch_size;
auto token_num = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
constexpr int PackSize = 16 / sizeof(DataType_);
const int blocksize = 128;
int grid_size = 1;

if (speculate_decoder) {
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
} else {
const uint32_t elem_nums = bsz * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
}
return {};
}

std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;

meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = seq_lens_encoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}

PD_BUILD_STATIC_OP(prefill_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_decoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));

PD_BUILD_STATIC_OP(decode_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_encoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int",
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
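Note (added for clarity, not part of the diff): both MLA write paths above derive their launch size the same way: each thread moves one 16-byte vector, so the element count is divided by the pack size before a grid size is chosen. A rough sketch of that computation, assuming GetNumBlocks<128> behaves like a plain ceiling division (it may additionally cap the grid by device occupancy); grid_for_packs is an illustrative name.

inline int grid_for_packs(uint32_t elem_nums, int elem_bytes) {
  const int pack_size = 16 / elem_bytes;       // elements per 128-bit vector
  const int pack_num = elem_nums / pack_size;  // one pack handled per thread
  const int block_threads = 128;
  return (pack_num + block_threads - 1) / block_threads;  // upper bound on blocks
}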
@@ -1,240 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"

template <typename T, int VecSize = 1>
__global__ void decode_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
LoadT src_vec;

int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];

if (write_seq_id == 0) continue;

const int* block_table_now = nullptr;

block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;

if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
start_token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
start_token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}

template <typename T, int VecSize = 1>
__global__ void speculate_decode_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
LoadT src_vec;

int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / hidden_size;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;

const int* block_table_now = nullptr;

block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
}
if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
token_id * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
token_id * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}

template <typename T, int VecSize = 1>
__global__ void prefill_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;

int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
const uint32_t block_offset = ori_seq_id % block_size;

if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
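Note (added for clarity, not part of the diff): all three kernels above write into a cache laid out as [num_blocks, kv_num_heads, block_size, nope_size + pe_size]. A minimal sketch of the flattened destination index they compute; mla_cache_index is an illustrative name.

inline uint32_t mla_cache_index(int block_idx, int head_idx, int block_offset,
                                int h_bias, int kv_num_heads, int block_size,
                                int all_size) {
  // add nope_size to h_bias when addressing the pe part of the row
  return block_idx * kv_num_heads * block_size * all_size +
         head_idx * block_size * all_size +
         block_offset * all_size + h_bias;
}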
@@ -1,38 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "utils.cuh"

template <typename T>
void DecodeMLAAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);
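Note (added for clarity, not part of the diff): the hunks below repeatedly pair two ways of recovering a token's batch id and start offset; the scraped compare view drops the +/- markers, so both sides appear back to back. A small sketch of the two forms, with illustrative helper names only.

inline int bid_from_padding(int token_id, const int* padding_offsets, int max_seq_len) {
  // one side of the diff: recover the batch id from padded token layout
  return (token_id + padding_offsets[token_id]) / max_seq_len;
}
inline int bid_from_map(int token_id, const int* batch_id_per_token) {
  // the other side: a direct token-to-batch lookup table
  return batch_id_per_token[token_id];
}
// start_token_idx is likewise either bid * max_seq_len - cum_offsets[bid]
// or cu_seqlens_q[bid], depending on which side of the diff is taken.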
@@ -26,8 +26,8 @@ __global__ void append_clear_cache_int8_block(
// block_size, head_size // 2]
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
@@ -41,10 +41,10 @@ __global__ void append_clear_cache_int8_block(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;

const int bid = batch_id_per_token[token_id];

const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;

if (seq_lens_encoder[bid] > 0) return;
@@ -100,8 +100,8 @@ __global__ void append_clear_cache_int4_block(
// block_size, head_size // 2]
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
@@ -115,10 +115,10 @@ __global__ void append_clear_cache_int4_block(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;

const int bid = batch_id_per_token[token_id];

const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;

if (seq_lens_encoder[bid] > 0) return;
@@ -178,8 +178,8 @@ __global__ void append_speculate_cache_rope_kernel(
// head_size // 2]
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
@@ -214,12 +214,12 @@ __global__ void append_speculate_cache_rope_kernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / hidden_size;
const int ori_bi = batch_id_per_token[token_id];
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -235,7 +235,7 @@ __global__ void append_speculate_cache_rope_kernel(
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
cum_offsets[ori_bi]);
}
const int block_offset = write_seq_id % block_size;

@@ -311,8 +311,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
@@ -347,12 +347,12 @@ __global__ void append_speculate_cache_neox_rope_kernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / half_hidden_size;
const int ori_bi = batch_id_per_token[token_id];
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -368,7 +368,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
cum_offsets[ori_bi]);
}
const int block_offset = write_seq_id % block_size;

@@ -458,8 +458,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -484,10 +484,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;

const int bid = batch_id_per_token[token_id];

const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -690,8 +690,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -716,10 +716,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;

const int bid = batch_id_per_token[token_id];

const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;

@@ -1068,8 +1068,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1097,10 +1097,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int lane_id = tid % 32;

const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;

const int bid = batch_id_per_token[token_id];

const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;

const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -1130,10 +1130,6 @@ __global__ void append_speculate_cache_int4_rope_kernel(
LoadOutScaleT out_scale_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
#pragma unroll
for (int v_i = 0; v_i < VecSize; v_i++) {
bias_vec[v_i] = 0;
}
const InT* qkv_now = quant_qkv + token_id * hidden_size;
T* qkv_out_now = qkv_out + token_id * hidden_size;
#pragma unroll
@@ -1141,8 +1137,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<InT, VecSize>(&qkv_now[bias_idx], &src_vec);
// Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
@@ -1152,10 +1148,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
// input_left = input_left * out_scale_vec[2 * i] +
// static_cast<float>(bias_vec[2 * i]);
// input_right = input_right * out_scale_vec[2 * i + 1] +
// static_cast<float>(bias_vec[2 * i + 1]);
input_left = input_left * out_scale_vec[2 * i] +
static_cast<float>(bias_vec[2 * i]);
input_right = input_right * out_scale_vec[2 * i + 1] +
static_cast<float>(bias_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] =
@@ -1171,35 +1167,6 @@ __global__ void append_speculate_cache_int4_rope_kernel(
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;

if (block_offset == 0) {
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
if (head_idx < num_heads + gqa_group_size) {
constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE;
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
block_size * half_head_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim;
block_i < block_size;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &key_cache[tgt_idx + block_i * half_head_size]);
}
} else {
const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE;
const int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
HeadDim * half_block_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]);
}
}
}
constexpr int K_VEC_SIZE = 4;
constexpr int HALF_K_VEC_SIZE = 2;
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
@@ -1215,11 +1182,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
LoadScaleT zp_vec1, zp_vec2;
LoadEmbT cos_emb_vec1, cos_emb_vec2;
LoadEmbT sin_emb_vec1, sin_emb_vec2;
#pragma unroll
for (int v_i = 0; v_i < HALF_K_VEC_SIZE; v_i++) {
bias_vec1[v_i] = 0;
bias_vec2[v_i] = 0;
}

const InT* qkv_now = quant_qkv + token_id * hidden_size;
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
//////////
@@ -1228,11 +1191,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
/////
// Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
// Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
// Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
// Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
// &out_scale_vec2);
Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
&out_scale_vec2);
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
@@ -1252,10 +1215,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(

float input_left = static_cast<float>(src_vec1[0]);
float input_right = static_cast<float>(src_vec1[1]);
// input_left =
// input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
// input_right =
// input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
|
||||
input_left =
|
||||
input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
|
||||
input_right =
|
||||
input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
@@ -1270,10 +1233,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
|
||||
input_left = static_cast<float>(src_vec2[0]);
|
||||
input_right = static_cast<float>(src_vec2[1]);
|
||||
// input_left =
|
||||
// input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
|
||||
// input_right =
|
||||
// input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
|
||||
input_left =
|
||||
input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
|
||||
input_right =
|
||||
input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
@@ -1411,8 +1374,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1440,10 +1403,10 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const int lane_id = tid % 32;
|
||||
|
||||
const int token_id = blockIdx.x;
|
||||
const int ori_token_id = token_id + padding_offsets[token_id];
|
||||
const int bid = ori_token_id / max_seq_len;
|
||||
|
||||
const int bid = batch_id_per_token[token_id];
|
||||
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
|
||||
@@ -1829,4 +1792,4 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
(uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
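Note: the hunks above juxtapose the two indexing schemes this diff switches between. One recovers the batch id from the padded [bsz, max_seq_len] layout via padding_offsets and cum_offsets; the other looks it up directly through batch_id_per_token and reads the sequence start from cu_seqlens_q. A minimal sketch of the two computations, reusing the kernel's variable names (bid_old / bid_new and the comments are illustrative, not code from the diff):

// Padded-layout scheme: rebuild the padded position, then divide out the row.
// cum_offsets[bid] appears to hold the padding accumulated before batch bid,
// as implied by the arithmetic in the hunks above.
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid_old = ori_token_id / max_seq_len;
const int start_token_old = bid_old * max_seq_len - cum_offsets[bid_old];

// Direct-lookup scheme: one table maps every token to its batch, and
// cu_seqlens_q[bid] is that batch's first query-token index.
const int bid_new = batch_id_per_token[token_id];
const int start_token_new = cu_seqlens_q[bid_new];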
@@ -22,8 +22,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* padding_offsets,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -59,8 +59,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
seq_lens,
cos_emb,
sin_emb,
@@ -82,8 +82,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
seq_lens,
cos_emb,
sin_emb,
@@ -106,8 +106,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* padding_offsets,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -136,8 +136,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
seq_lens,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
seq_lens_encoder,
max_seq_len,
max_blocks_per_seq,
@@ -151,8 +151,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -175,8 +175,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -201,8 +201,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* padding_offsets,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -233,8 +233,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
seq_lens,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
seq_lens_encoder,
max_seq_len,
max_blocks_per_seq,
@@ -248,8 +248,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -274,8 +274,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
padding_offsets,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -301,8 +301,8 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -349,8 +349,8 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -376,8 +376,8 @@ void SpeculateWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -409,8 +409,8 @@ void SpeculateWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -442,8 +442,8 @@ void SpeculateWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -488,8 +488,8 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -514,8 +514,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -539,8 +539,8 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -566,8 +566,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -582,4 +582,4 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out);
@@ -23,8 +23,8 @@ void SpeculateWriteCacheWithRoPEKernel(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -39,4 +39,4 @@ void SpeculateWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out);
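Note: these launcher signatures and declarations thread the same pair of inputs through every level: batch_id_per_token ([num_tokens]) and cu_seqlens_q, alongside the older padding_offsets / cum_offsets. As a rough illustration of how the two arrays relate to per-sequence query lengths, here is a hedged host-side sketch; the helper name BuildQueryIndexArrays and the seq_lens_this_time_cpu input are assumptions for illustration only, not code from this diff:

#include <vector>

// cu_seqlens_q: exclusive prefix sum of per-batch query lengths, size bsz + 1.
// batch_id_per_token: owning batch id for every flattened query token.
void BuildQueryIndexArrays(const std::vector<int>& seq_lens_this_time_cpu,
                           std::vector<int>* cu_seqlens_q,
                           std::vector<int>* batch_id_per_token) {
  const int bsz = static_cast<int>(seq_lens_this_time_cpu.size());
  cu_seqlens_q->assign(bsz + 1, 0);
  batch_id_per_token->clear();
  for (int b = 0; b < bsz; ++b) {
    (*cu_seqlens_q)[b + 1] = (*cu_seqlens_q)[b] + seq_lens_this_time_cpu[b];
    for (int t = 0; t < seq_lens_this_time_cpu[b]; ++t) {
      batch_id_per_token->push_back(b);  // token -> owning batch id
    }
  }
}

With this construction, cu_seqlens_q[bid] is exactly the start_token_idx that the kernels above read for batch bid.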
@@ -37,8 +37,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -37,8 +37,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -38,8 +38,8 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -85,8 +85,8 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -80,8 +80,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -82,8 +82,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -82,8 +82,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -81,8 +81,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -36,8 +36,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -81,8 +81,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -22,8 +22,8 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,
@@ -21,8 +21,8 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,
@@ -21,8 +21,8 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,
@@ -21,8 +21,8 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,
@@ -25,7 +25,6 @@ struct AppendAttnMetaData {
int kv_num_heads;
int token_nums;
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
};
@@ -310,56 +309,10 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} \
}

#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}

#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
case 512: { \
constexpr size_t HEAD_DIM = 512; \
__VA_ARGS__ \
break; \
} \
case 576: { \
constexpr size_t HEAD_DIM = 576; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}

#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
if (num_stage == 2) { \
constexpr size_t NUM_STAGE = 2; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_stage: ", num_stage); \
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
if (num_stage == 2) { \
constexpr size_t NUM_STAGE = 2; \
__VA_ARGS__ \
}

#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
@@ -375,13 +328,10 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
constexpr size_t cache_bytes = 4; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the cache_type: ", cache_type); \
}

#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
if (deal_each_time == 32) { \
if (deal_each_time == 32) { \
constexpr size_t DEAL_EACH_TIME = 32; \
__VA_ARGS__ \
} else if (deal_each_time == 64) { \
@@ -437,20 +387,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the group_size", group_size); \
}

#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 128) { \
constexpr size_t GROUP_SIZE = 128; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
}

#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
if (block_shape_q <= 16) { \
constexpr size_t BLOCK_SHAPE_Q = 16; \
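Note: the DISPATCH_* macros above all follow the same compile-time dispatch pattern: compare a runtime value, bind an identically named constexpr inside the matching branch, and expand the trailing __VA_ARGS__ in that scope so the value can be used as a template argument. A hedged sketch of a typical call site follows; LaunchAppendAttentionKernel and params are placeholders for illustration, not identifiers from this diff:

// head_dims is known only at runtime; inside the braces HEAD_DIM is a
// constexpr size_t (128 or 192), so it can parameterize a template.
DISPATCH_GQA_HEAD_DIM(meta_data.head_dims, HEAD_DIM, {
  LaunchAppendAttentionKernel<HEAD_DIM>(params);  // hypothetical launcher
});

Unsupported values fall through to a PD_THROW branch where one is present, which is why most of these macros end with an explicit error message.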
@@ -30,4 +30,4 @@ inline int getSMVersion()
return sm_major * 10 + sm_minor;
}

}
}
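Note: the last hunk touches getSMVersion(), which reports the GPU compute capability encoded as major * 10 + minor (for example 80 on SM80/A100, 90 on SM90/H100). A minimal self-contained equivalent using the CUDA runtime API, shown only as a sketch of the idea rather than the file's exact implementation:

#include <cuda_runtime.h>

inline int getSMVersion() {
  int device = 0;
  cudaGetDevice(&device);
  int sm_major = 0;
  int sm_minor = 0;
  cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
  return sm_major * 10 + sm_minor;
}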