Compare commits


54 Commits

Author SHA1 Message Date
lizhenyun01
d40a1046de [Feature] support rl_tp_degree (#3934)
* [Feature] support rl_tp_degree

* add rl_tp_degree in lmhead

* add rl_tp_degree in bias

* fix split_axis=0 in bias

* fix split_axis in weight

* fix bias rl_tp_degree

* fix bias rl_tp_degree

* change attr to dict

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-09-08 16:20:32 +08:00
Sunny-bot1
fa2369271d update env docs for Machete (#3960) 2025-09-08 14:44:52 +08:00
Zhang Yulong
8903f937f9 update ci (#3953) 2025-09-08 14:21:25 +08:00
luukunn
1023a67765 [BugFix] fix default parser (#3932)
* add reasoning parser plugin

* fix finish reason

* fix default parser

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-09-08 14:12:13 +08:00
Zero Rains
d43549953c [Cherry-Pick][Bug Fix]fix the bug for real size 0 in cudagraph (#3888)
* fix the bug for real size 0 in cudagraph

* fix cache_messager

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-09-08 14:06:10 +08:00
Yuanle Liu
c7c1627456 Update paddleformers version to >=0.2.3 (#3936)
* Update paddleformers version to 0.2.2

* Update requirements.txt

* Update paddleformers version to >=0.2.3
2025-09-08 11:11:05 +08:00
ming1753
d6bf6de5e6 [Bug Fix] Fix mm performance degradation (#3942)
* [Bug Fix] Fix mm performance degradation

* formate

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: chenjian <1435317881@qq.com>
2025-09-08 00:32:22 +08:00
chenjian
38e734e183 [Feature] support hierarchical cache in v1 (#3939) 2025-09-08 00:31:34 +08:00
bukejiyu
051e4a881c ignore (#3949) 2025-09-07 23:57:48 +08:00
chenjian
b2bb37d7c0 [Fix] when prompt token ids is numpy (#3944) 2025-09-07 23:02:03 +08:00
CSWYF3634076
c6e2a37a95 [BugFix] qwen2.5vl enable_thinking=true bug fix (#3920) 2025-09-07 21:06:36 +08:00
chenjian
8d77c1cb51 [Optimize] optimize prefix cache in release22 (#3889)
* optimize prefix cache in release22

* optimize prefix cache in release22

* fix worker

* fix

* fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-09-06 09:52:01 +08:00
chenjian
41cd3e24c9 [Feature] Enable prefix caching as default (#3816)
* [Feature] Enable prefix caching as default

* [Feature] Enable prefix caching as default

* Set prefix caching as default

* skip dynamic load

* fix kill bug

* fix kill bug

* fix kill bug

* fix ci

* fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-09-06 09:51:34 +08:00
Zhang Yulong
11b18e5ef0 add cache queue port (#3904) (#3926)
* add cache queue port

* add cache queue port

* add cache queue port
2025-09-06 00:00:12 +08:00
freeliuzc
e2c764fd5a update hybrid-mtp-with-ngram (#3924) 2025-09-05 23:06:57 +08:00
lizhenyun01
2d975e16b0 [BugFix] fix TaskQueue dp_id in multi node (#3919) 2025-09-05 22:29:26 +08:00
chenjian
8915c8411d Revert "[Feature] Setting number of apiserver workers automatically (#3794)" (#3918)
This reverts commit d1d063e4af.
2025-09-05 21:06:50 +08:00
yinwei
77c1bd0813 [XPU]Fixed the issue of performance degradation caused by enabling ENABLE_V1_KVCACHE_SCHEDULER (#3900)
* fix bug

* fix bug

* update

* udpate

* update
2025-09-05 19:17:25 +08:00
Yuanle Liu
473cde779f paddleformers==0.2.1 (#3925) 2025-09-05 19:06:15 +08:00
chen
335d1c8e8f 【CP】Compatible with EB 0.3B torch model arch (#3914)
* fix

* check
2025-09-05 19:05:07 +08:00
ltd0924
173e4df982 [Fix] mv connection_manager init (#3902)
* Update serving_chat.py

* Update serving_completion.py

* Update serving_completion.py

* mv connection_manager init

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
2025-09-05 17:42:36 +08:00
lizhenyun01
199f88ce1e support tpep weight load (#3882) 2025-09-05 13:56:29 +08:00
ltd0924
55ebe855c0 [Feature] support controller port in multi api server (#3895)
* fix scheduler bug

* fix

* Update api_server.py

* Update multi_api_server.py
2025-09-05 13:38:58 +08:00
zhouchong
deb7ad205f fix qwen_vl_processor miss image_patch_id (#3894)
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-09-05 11:32:34 +08:00
Yuanle Liu
e9f72df918 paddleformers==0.1.4 (#3908) 2025-09-05 11:25:57 +08:00
chenjian
8567ada09e [Fix] disable scheduler v1 in guided decoding (#3877)
* disable scheduler v1 in guided decoding

* disable scheduler v1 in guided decoding
2025-09-04 20:54:55 +08:00
YuBaoku
afcde19277 [CI] update paddleformers==0.2 in release/2.2 (#3828)
* [DEBUG] Adapt validation for paddleformers==0.2 in release/2.2

* [CI] update paddleformers==0.2 in release/2.2
2025-09-04 20:12:37 +08:00
lizhenyun01
d40d3a5a4f fix DP&&TP (#3872)
2025-09-04 14:38:26 +08:00
luukunn
b8d0f1c081 [bug] fix finish reason (#3858)
* add reasoning parser plugin

* fix finish reason

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
2025-09-04 14:36:03 +08:00
ltd0924
8550e19008 [bugfix] scheduler (#3871)
* fix scheduler bug

* fix

* Update api_server.py
2025-09-04 11:34:12 +08:00
chenjian
a0c03510c0 [Bug fix] Fix prompt token ids dtype in v1 (#3861) 2025-09-04 11:02:37 +08:00
chenjian
fb1e0d6a87 [Feature] Set scheduler v1 as default (#3812)
* [Feature] Set scheduler v1 as default

* [Feature] Set scheduler v1 as default

* [Feature] Set scheduler v1 as default

* [Feature] Set scheduler v1 as default

* [Feature] Set scheduler v1 as default

* [Feature] Set scheduler v1 as default
2025-09-04 11:02:10 +08:00
gaoziyuan
fbf0e9d2aa fix mem boom in ep (#3852) 2025-09-04 10:38:34 +08:00
SunLei
8c0e7d6fe9 Support for async processor added. (#3870)
* Support for async processor added.

* remove yappi code
2025-09-04 10:35:08 +08:00
yangjianfengo1
b56b015d85 fix port (#3865)
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
2025-09-04 10:02:08 +08:00
ming1753
1432e336d7 [Bug Fix] Fix bug of multimodal inputs only text (#3850)
2025-09-03 19:48:10 +08:00
yangjianfengo1
9213a58a06 [Fix bug] Pin the w4afp8 nblock to 256 and add a mask parameter to fa3 append attn (#3771) (#3835)
* fix w4afp8

* Add centralized configuration

* codestyle

* fix fa3 append attn
2025-09-03 19:36:45 +08:00
plusNew001
87ef0f5d30 [XPU] Update XPU stable xvllm and xtdk version for 2.2 & Change CI Case (#3855)
* Update no_proxy environment variable in CI workflow

* Install lsof and kill api_server processes

Install lsof tool and kill processes using it.

* Update dependency versions for stable release

* Update CI script to use stable dependencies
2025-09-03 19:33:06 +08:00
plusNew001
abcd2148c0 [XPU]Update XPU CI Case (#3844)
* Update no_proxy environment variable in CI workflow

* Install lsof and kill api_server processes

Install lsof tool and kill processes using it.
2025-09-03 15:29:47 +08:00
gaoziyuan
05b6591c23 【BugFix】add moe noaux_tc tatics in trition backend (#3821)
* add moe noaux_tc tatics in trition backend

* fix

* add dp config
2025-09-03 13:28:44 +08:00
plusNew001
42402c80e9 Update installation method for paddlepaddle-xpu (#3834) 2025-09-03 11:28:27 +08:00
luukunn
1968c65849 add reasoning parser plugin (#3820) 2025-09-03 11:17:13 +08:00
ltd0924
37cb37b7f2 [BugFix] fix scheduler (#3818)
* fix scheduler bug

* fix
2025-09-03 11:16:49 +08:00
bukejiyu
f975f7de2f [v1loader]Reduce EB300B model loading time (#3700) (#3810)
* speed up eb45

* update
2025-09-03 10:14:31 +08:00
Yuanle Liu
174510180a [BugFix] fix error of import paddle.base.core.Config (#3761) (#3804)
* Defer the import of Config

* support chunked_prefill

* support chunked_prefill
2025-09-03 10:14:03 +08:00
ltd0924
5cda326ba2 Update qwen_vl_processor.py (#3806)
2025-09-02 21:56:24 +08:00
RAM
a6c8f17431 [Executor] Fix bug of import paddle with RLHF (#3781) (#3817) 2025-09-02 21:42:59 +08:00
ltd0924
cd09384a14 [BugFix] fix max streaming tokens invalid (#3799)
* Update serving_chat.py

* Update serving_completion.py

* Update serving_completion.py
2025-09-02 21:03:13 +08:00
ltd0924
0f42771a84 [Feature] support model weight update in ep (#3802)
* Update config.py

* Update ep.py

* Update fused_moe_backend_base.py

* Update dynamic_weight_manager.py

* Update worker_process.py

* fix ci
2025-09-02 20:52:47 +08:00
Jiang-Jia-Jun
d1d063e4af [Feature] Setting number of apiserver workers automatically (#3794)
Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
2025-09-02 17:19:07 +08:00
kevin
a86b35ab49 Fix chunked prefill (#3778)
* update enable chunked_prefill

* update code

* update code

* update code
2025-09-02 13:41:55 +08:00
YUNSHEN XIE
0cdbc950b5 fix ce compile task upload error (#3788) 2025-09-02 11:52:50 +08:00
YUNSHEN XIE
2b0a745d57 fix ce build job (#3777) 2025-09-02 10:53:26 +08:00
Jiang-Jia-Jun
1953c7c759 Update FASTDEPLOY_VERSION to 2.2.0 2025-08-31 21:31:12 +08:00
485 changed files with 4270 additions and 26893 deletions


@@ -44,7 +44,7 @@ jobs:
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
@@ -160,7 +160,6 @@ jobs:
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
pushd tests/ce/deploy
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \


@@ -44,7 +44,7 @@ jobs:
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
@@ -160,7 +160,6 @@ jobs:
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
pushd tests/ce/deploy
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \


@@ -55,7 +55,7 @@ on:
jobs:
fd-build:
runs-on: [self-hosted, GPU-Build]
timeout-minutes: 360
timeout-minutes: 240
outputs:
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
steps:
@@ -134,7 +134,6 @@ jobs:
fi
git config --global --add safe.directory /workspace/FastDeploy
chown -R $(whoami) /workspace/FastDeploy
cd FastDeploy
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)


@@ -1,73 +0,0 @@
name: Docker Build
description: "FastDeploy CI Image Build"
on:
workflow_call:
inputs:
CI_DOCKER_IMAGE_NAME:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
DOCKER_IMAGE_NAME:
description: "Build Images"
required: false
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
outputs:
docker_name_precheck:
description: "Output path of the generated wheel"
value: ${{ jobs.docker_build.outputs.docker_name_precheck }}
jobs:
docker_build:
runs-on: [self-hosted, Docker-Build]
outputs:
docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
steps:
- name: Docker Build
id: docker_build
shell: bash
env:
docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }}
docker_image: ${{ inputs.DOCKER_IMAGE_NAME }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
# Docker Build
cd tools/dockerfile/
set -e
cp ../../requirements.txt ./
cp ../../scripts/unittest_requirement.txt ./
docker build -t ${docker_image_name} -f Dockerfile.ci . \
--network host \
--no-cache
docker push ${docker_image_name}
echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT


@@ -39,7 +39,6 @@ jobs:
docker_image: ${{ inputs.DOCKER_IMAGE }}
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
run: |
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
@@ -117,6 +116,7 @@ jobs:
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
@@ -147,7 +147,6 @@ jobs:
--skip install
cd PaddleTest/framework/ServeTest
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \


@@ -46,7 +46,7 @@ jobs:
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \


@@ -44,7 +44,7 @@ jobs:
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \


@@ -41,7 +41,7 @@ jobs:
run_tests_with_coverage:
runs-on: [self-hosted, GPU-h1z1-2Cards]
timeout-minutes: 90
timeout-minutes: 60
needs: check_cov_skip
if: needs.check_cov_skip.outputs.can-skip != 'true'
outputs:
@@ -60,7 +60,7 @@ jobs:
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
@@ -171,7 +171,10 @@ jobs:
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install -r scripts/unittest_requirement.txt
python -m pip install coverage
python -m pip install diff-cover
python -m pip install pytest-cov
python -m pip install jsonschema aistudio_sdk==0.3.5
python -m pip install ${fd_wheel_url}
rm -rf fastdeploy
# coverage subprocess use


@@ -9,7 +9,7 @@ on:
permissions: read-all
concurrency:
group: CE-Job-${{ github.ref }}-${{ github.sha }}
group: ${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
@@ -199,13 +199,13 @@ jobs:
ls
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
echo "commit wheel url is ${WHEEL_PATH}"
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
echo "commit wheel url is ${WHEEL_PATH}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
ce_upload_sm8689:
@@ -224,9 +224,9 @@ jobs:
python-version: '3.10'
- name: Wheel Info Show and Upload
run: |
echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
@@ -238,11 +238,11 @@ jobs:
ls
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
echo "commit wheel url is ${WHEEL_PATH}"
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
echo "commit wheel url is ${WHEEL_PATH}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
echo "latest wheel url is ${WHEEL_PATH_LATEST}"


@@ -1,174 +0,0 @@
name: CI Images Build
on:
workflow_dispatch:
schedule:
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
permissions: read-all
concurrency:
group: CI-Images-Build-${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
if [[ "${{ github.ref_type }}" == "tag" ]]; then
commit_id=${{ github.sha }}
tag_name=${{ github.ref_name }}
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
else
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
ci_image_build:
name: CI Images Build
needs: clone
uses: ./.github/workflows/_ci_image_build.yml
with:
CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
build_sm8090:
name: BUILD_SM8090
needs: [clone, ci_image_build]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "90"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
unittest_coverage:
name: Run FastDeploy Unit Tests and Coverage
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build_sm8090,ci_image_build]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
publish_pre_check:
name: Publish Docker Images Pre Check
needs: [ci_image_build, unittest_coverage,logprob_test,pre_ce_test,base_test,accuracy_test,stable_test]
runs-on: [self-hosted, Docker-Build]
steps:
- name: Images Uploading
env:
images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
run: |
echo "images_name=${images_name}"
docker images ${ci_image_name}
docker tag ${images_name} ${ci_image_name}
docker push ${ci_image_name}


@@ -21,7 +21,7 @@ jobs:
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "90"
COMPILE_ARCH: "89,90"
WITH_NIGHTLY_BUILD: "OFF"
FD_VERSION: "0.0.0"


@@ -13,7 +13,7 @@ on:
permissions: read-all
concurrency:
group: Publish-Job-${{ github.ref }}-${{ github.sha }}
group: ${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
@@ -319,13 +319,3 @@ jobs:
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

.gitmodules

@@ -1,10 +0,0 @@
[submodule "custom_ops/third_party/DeepGEMM"]
path = custom_ops/third_party/DeepGEMM
url = https://github.com/deepseek-ai/DeepGEMM.git
ignore = all
[submodule "custom_ops/third_party/cutlass"]
path = custom_ops/third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "custom_ops/third_party/nlohmann_json"]
path = custom_ops/third_party/nlohmann_json
url = https://github.com/nlohmann/json.git


@@ -26,8 +26,6 @@ English | [简体中文](README_CN.md)
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
## News
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It now offers compatibility with models in the HuggingFace ecosystem, has further optimized performance, and newly adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!
**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.
**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
@@ -59,9 +57,8 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU and MetaX GPU are currently under development and testing. Stay tuned for updates!
## Get Started
@@ -71,12 +68,20 @@ Learn how to use FastDeploy through our documentation:
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
- [Offline Inference Development](./docs/offline_inference.md)
- [Online Service Deployment](./docs/online_serving/README.md)
- [Full Supported Models List](./docs/supported_models.md)
- [Best Practices](./docs/best_practices/README.md)
## Supported Models
Learn how to download models, enable using the torch format, and more:
- [Full Supported Models List](./docs/supported_models.md)
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |
## Advanced Usage


@@ -26,9 +26,7 @@
# FastDeploy: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
## News
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It offers compatibility with models in the HuggingFace ecosystem, further optimizes performance, and adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!
**[2025-08] FastDeploy v2.1 released:** a brand-new KV Cache scheduling strategy, PD disaggregation and CUDA Graph support for more models, enhanced support for more hardware such as Kunlun and Hygon, and comprehensive performance optimizations for the service and inference engine.
**[2025-08] 🔥 FastDeploy v2.1 is newly released:** a brand-new KV Cache scheduling strategy, PD disaggregation and CUDA Graph support for more models, enhanced support for more hardware such as Kunlun and Hygon, and comprehensive performance optimizations for the service and inference engine.
**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch such as bone-china mugs and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
@@ -57,9 +55,8 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
- [Iluvatar CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
- [Enflame S60](./docs/zh/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/zh/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/zh/get_started/installation/metax_gpu.md)
**Note:** We are actively working on expanding hardware support. Other hardware platforms, including Ascend NPU, are currently under development and testing. Stay tuned for updates!
**Note:** We are actively working on expanding hardware support. Other hardware platforms, including Ascend NPU and MetaX GPU, are currently under development and testing. Stay tuned for updates!
## Get Started
@@ -69,12 +66,20 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/zh/get_started/ernie-4.5-vl.md)
- [Offline Inference Development](./docs/zh/offline_inference.md)
- [Online Service Deployment](./docs/zh/online_serving/README.md)
- [Full Supported Models List](./docs/zh/supported_models.md)
- [Best Practices](./docs/zh/best_practices/README.md)
## Supported Models
Learn how to download models, enable using the torch format, and more:
- [Full Supported Models List](./docs/zh/supported_models.md)
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |
## Advanced Usage


@@ -6,4 +6,3 @@ tensor_parallel_size: 8
max_num_batched_tokens: 4096
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint4


@@ -1,6 +0,0 @@
tensor_parallel_size: 1
max_model_len: 131072
max_num_seqs: 32
quantization: wint4
max_num_batched_tokens: 8192
plas_attention_config: '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'


@@ -6,4 +6,3 @@ tensor_parallel_size: 8
max_num_batched_tokens: 4096
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint8


@@ -1,5 +0,0 @@
max_model_len: 32768
max_num_seqs: 256
kv_cache_ratio: 0.75
tensor_parallel_size: 4
gpu_memory_utilization: 0.9


@@ -13,4 +13,3 @@ pd_comm_port: "2334"
max_num_batched_tokens: 384
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint4


@@ -10,4 +10,3 @@ engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"
quantization: wint4


@@ -1,11 +0,0 @@
enable_mm: True
max_model_len: 131072
max_num_seqs: 56
gpu_memory_utilization: 0.8
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint4
limit_mm_per_prompt: '{"image": 100, "video": 100}'
enable_chunked_prefill: True
max_num_batched_tokens: 384
reasoning_parser: ernie-45-vl


@@ -1,7 +1,7 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 36
gpu_memory_utilization: 0.9
gpu_memory_utilization: 0.95
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint8


@@ -1,7 +1,7 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 36
gpu_memory_utilization: 0.85
gpu_memory_utilization: 0.8
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint8


@@ -1,9 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
reasoning_parser: ernie-45-vl


@@ -1,10 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
quantization: wint4
reasoning_parser: ernie-45-vl


@@ -1,10 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
quantization: wint8
reasoning_parser: ernie-45-vl


@@ -1 +0,0 @@
max_tokens: 131071


@@ -1 +0,0 @@
max_tokens: 12288


@@ -2,7 +2,7 @@ top_p: 0.95
temperature: 0.6
metadata:
min_tokens: 1
max_tokens: 131071
max_tokens: 65535
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0


@@ -1,6 +0,0 @@
tensor_parallel_size: 1
max_model_len: 131072
max_num_seqs: 32
reasoning_parser: ernie_x1
tool_call_parser: ernie_x1
load_choices: "default_v1"


@@ -143,9 +143,9 @@ function build_and_install_ops() {
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
if [ "$is_xpu" = "True" ]; then
cd xpu_ops
cd xpu_ops/src
bash build.sh ${TMP_DIR_REAL_PATH}
cd ..
cd ../..
elif [ "$FD_CPU_USE_BF16" == "true" ]; then
if [ "$FD_BUILDING_ARCS" == "" ]; then
FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}


@@ -14,7 +14,7 @@
#include "paddle/extension.h"
void set_value_by_flags_and_idx(const bool *stop_flags,
void set_value_by_flag_and_id(const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
@@ -50,7 +50,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
int length = pre_ids_all_shape[1];
int length_input_ids = input_ids.shape()[1];
set_value_by_flags_and_idx(stop_flags.data<bool>(),
set_value_by_flag_and_id(stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
input_ids.data<int64_t>(),
seq_lens_encoder.data<int>(),


@@ -46,7 +46,7 @@ void update_inputs_kernel(bool *not_need_stop,
not_need_stop[0] = stop_sum < stop_nums[0];
}
void UpdateInputs(const paddle::Tensor &stop_flags,
void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
@@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu)
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"input_ids", "input_ids_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputs));
.SetKernelFn(PD_KERNEL(UpdateInputes));


@@ -140,8 +140,8 @@ void AppendAttentionKernel(
key_cache,
value_cache,
attn_mask,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
@@ -273,15 +273,11 @@ void AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
@@ -300,15 +296,11 @@ void AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
}
} else {
if (qkv_out_scales) {
@@ -317,6 +309,7 @@ void AppendAttentionKernel(
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -343,6 +336,7 @@ void AppendAttentionKernel(
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,


@@ -52,7 +52,6 @@ __global__ void multi_query_append_attention_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -75,11 +74,6 @@ __global__ void multi_query_append_attention_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -148,7 +142,7 @@ __global__ void multi_query_append_attention_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -428,14 +422,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int32_t attn_mask_len = -1) {
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -452,11 +445,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -523,7 +511,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -914,7 +902,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -973,7 +960,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1089,7 +1075,7 @@ void MultiQueryAppendAttention(
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
int32_t attn_mask_len;
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
@@ -1148,7 +1134,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1221,7 +1206,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),

View File

@@ -57,7 +57,6 @@ __global__ void multi_query_append_attention_c4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -86,11 +85,6 @@ __global__ void multi_query_append_attention_c4_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -179,7 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -526,14 +520,13 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int32_t attn_mask_len = -1) {
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -556,11 +549,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -647,7 +635,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -1119,7 +1107,6 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1184,7 +1171,6 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1313,7 +1299,7 @@ void MultiQueryAppendC4Attention(
}
const int num_chunks = div_up(max_seq_len, chunk_size);
int32_t attn_mask_len;
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
@@ -1379,7 +1365,6 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1460,7 +1445,6 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),


@@ -32,15 +32,14 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__global__ void multi_query_append_attention_c8_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads]
const T *__restrict__ cache_v_scale, // [num_kv_heads]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -58,7 +57,6 @@ __global__ void multi_query_append_attention_c8_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -88,40 +86,33 @@ __global__ void multi_query_append_attention_c8_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
const uint32_t q_end =
@@ -189,7 +180,7 @@ __global__ void multi_query_append_attention_c8_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -210,13 +201,6 @@ __global__ void multi_query_append_attention_c8_kernel(
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + num_frags_z * 16;
}
const uint32_t num_iterations = div_up(
@@ -298,22 +282,10 @@ __global__ void multi_query_append_attention_c8_kernel(
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -346,7 +318,6 @@ __global__ void multi_query_append_attention_c8_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
const int ori_kv_idx_base = kv_idx_base;
kv_idx_base += num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -365,18 +336,6 @@ __global__ void multi_query_append_attention_c8_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
@@ -387,9 +346,7 @@ __global__ void multi_query_append_attention_c8_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
is_scale_channel_wise, IsFP8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -506,15 +463,14 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise=false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -533,14 +489,13 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int32_t attn_mask_len = -1) {
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -563,39 +518,32 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
// When CUDA graph captures prefill, more gridDim.x blocks may be launched than needed; excess blocks exit early.
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
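// Dynamic (block_wise_fp8) C8 keeps per-token scales in registers (2 K / 4 V values per
// 16-token fragment); static C8 keeps per-channel scales (4 K / 2 V values per 16-column
// fragment) or a single per-head scale.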
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
@@ -661,7 +609,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -686,13 +634,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
}
const uint32_t num_iterations = div_up(
CAUSAL
@@ -775,23 +716,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
commit_group();
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -824,7 +753,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
const uint32_t ori_kv_idx_base = kv_idx_base;
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -843,18 +771,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
@@ -865,9 +781,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
is_scale_channel_wise, IsFP8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -981,8 +895,7 @@ template <typename T,
uint32_t NUM_WARP_Q,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
void MultiQueryAppendC8Attention(
const AppendAttnMetaData &meta_data,
const paddle::Tensor &qkv,
@@ -1040,8 +953,7 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
num_frags_z * 16 * sizeof(T) * 2;
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
@@ -1058,9 +970,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -1078,9 +988,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1114,9 +1022,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -1134,9 +1040,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1171,7 +1075,6 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1230,7 +1133,6 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1316,8 +1218,7 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
constexpr uint32_t smem_size =
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1334,9 +1235,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1354,9 +1253,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1372,7 +1269,7 @@ void MultiQueryAppendC8Attention(
}
const int num_chunks = div_up(max_seq_len, chunk_size);
int32_t attn_mask_len;
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
@@ -1398,9 +1295,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1418,9 +1313,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1457,7 +1350,6 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1532,7 +1424,6 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1655,7 +1546,6 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out) {
const auto token_num = meta_data.token_nums;
@@ -1664,7 +1554,6 @@ void CascadeAppendAttentionC8Kernel(
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim = meta_data.head_dims;
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";
DISPATCH_CAUSAL(
causal,
@@ -1683,46 +1572,43 @@ void CascadeAppendAttentionC8Kernel(
BLOCK_SIZE,
{DISPATCH_BLOCKSHAPE_Q(
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
BLOCK_SIZE,
CAUSAL,
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL,
IsFP8,
IsDynamicC8>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale.get(),
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
out);
})})})})})})})
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
BLOCK_SIZE,
CAUSAL,
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL, IsFP8>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale.get(),
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
out);
})})})})})})
}

View File

@@ -384,113 +384,6 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// all 4 warps cooperatively load the block_size per-token scales
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx / 4;
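// each 16-token fragment contributes two scales per thread: rows row_id and row_id + 8,
// which this thread later dequantizes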
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
}
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// all 4 warps cooperatively load the block_size per-token scales
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
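// per 16-token fragment this thread dequantizes positions row_id, row_id + 1, row_id + 8
// and row_id + 9, so keep those four scales in registers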
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
}
}
}
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
@@ -923,8 +816,7 @@ template <uint32_t num_frags_x,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
@@ -968,27 +860,20 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
#pragma unroll
@@ -1026,7 +911,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
const int32_t attn_mask_len,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const uint32_t tx = threadIdx.x;
@@ -1044,13 +929,13 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len)) {
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
@@ -1208,9 +1093,7 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool is_scale_channel_wise = false, bool IsFP8=false>
__device__ __forceinline__ void compute_sfm_v_c8(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1252,28 +1135,16 @@ __device__ __forceinline__ void compute_sfm_v_c8(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
@@ -1300,9 +1171,7 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool is_scale_channel_wise = false, bool IsFP8=false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1346,28 +1215,16 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16

View File

@@ -103,7 +103,6 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -265,10 +264,9 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
} else if (cache_quant_type_str == "cache_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
@@ -301,7 +299,6 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {

View File

@@ -18,53 +18,6 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
// Note(ZKK)
// This helper applies rotary position embedding (RoPE) to the HEAD_DIM elements of one head,
// reading from `input` and writing the rotated values to `output`.
template <typename T, int VecSize=8, int HEAD_DIM=128, int NUM_THREADS=32>
__device__ __forceinline__ void apply_rope(
const T* input,
const float* cos_emb,
const float* sin_emb,
T* output,
const int thread_id) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
#pragma unroll
for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM; head_bias += NUM_THREADS * VecSize) {
Load<T, VecSize>(&input[head_bias], &src_vec);
const uint32_t emb_idx = head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
out_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
out_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(out_vec, &output[head_bias]);
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -75,7 +28,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -167,6 +120,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) { // q
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
@@ -175,7 +129,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
}
} else { // k
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
@@ -211,7 +164,7 @@ __global__ void append_decode_cache_T_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -317,7 +270,7 @@ __global__ void append_decode_cache_T_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -428,142 +381,6 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int rotary_dim,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec, right_vec;
LoadBiasT left_bias_vec, right_bias_vec;
LoadKVT left_cache_vec, right_cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int half_head_size = head_size / 2;
const int half_rotary_dim = rotary_dim / 2;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int64_t half_hidden_size = hidden_size / 2;
// const int64_t offset = 2 * hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / half_hidden_size;
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
if (hi < num_heads && h_bias >= half_rotary_dim){
continue;
}
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
uint32_t ori_idx_left =
start_token_idx * hidden_size + hi * head_size + h_bias;
uint32_t ori_idx_right = ori_idx_left + half_head_size;
if (hi < num_heads){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else if (hi < num_heads + kv_num_heads){
if (h_bias < half_rotary_dim){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else{
ori_idx_left = ori_idx_left + half_rotary_dim;
ori_idx_right = ori_idx_left + half_rotary_dim;
}
}
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
if (h_bias < half_rotary_dim){
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// rope
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
} else {
// write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
uint32_t tgt_idx_left =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
if (hi < num_heads + kv_num_heads) {
if (h_bias < half_rotary_dim) {
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}else{
tgt_idx_left = tgt_idx_left + half_rotary_dim;
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
} else {
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
}
}
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -574,6 +391,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -687,6 +505,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -810,293 +629,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
}
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
T* __restrict__ cache_k_scale,
T* __restrict__ cache_v_scale,
const float* q_norm_weight,
const float* k_norm_weight,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d,
const float rms_norm_eps) {
static_assert(HeadDim == 128, "only HeadDim == 128 is supported for now!");
static_assert(VecSize == 4, "only VecSize == 4 is supported for now (32 threads x 4 elements)!");
constexpr int NUM_WARPS = 4;
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane_id = tid % 32;
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
const int* block_table_now = nullptr;
block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
int cache_offset;
if (head_idx < num_heads) {
cache_offset = 0;
} else if (head_idx < num_heads + 2 * kv_num_heads) {
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
}
T *cache_k_scale_now = cache_k_scale + cache_offset;
T *cache_v_scale_now = cache_v_scale + cache_offset;
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
#pragma unroll
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
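// rope_3d lays the cos/sin table out per batch element, so advance by bid tables of
// max_seq_len * HeadDim entries each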
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[2 * i] =
static_cast<T>(tmp1);
out_vec[2 * i + 1] =
static_cast<T>(tmp2);
}
// qk norm
if (q_norm_weight) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
LoadOutScaleT q_norm_vec;
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
}
}
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
}
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
if (block_offset == 0) {
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
if (head_idx < num_heads + kv_num_heads) {
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim;
block_i < block_size;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&key_cache[tgt_idx + block_i * HeadDim]);
}
} else {
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
const int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
}
}
__syncwarp();
}
constexpr int K_VEC_SIZE = 4;
constexpr int HALF_K_VEC_SIZE = 2;
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
using LoadEmbT = AlignedVector<float, 1>;
LoadKVResT cache_vec;
LoadT src_vec1, src_vec2;
LoadBiasT out_vec1, out_vec2;
LoadEmbT cos_emb_vec1, cos_emb_vec2;
LoadEmbT sin_emb_vec1, sin_emb_vec2;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
T scale = T(1.0f);
const int k_head_idx = head_idx - num_heads;
const int v_head_idx = head_idx - num_heads - kv_num_heads;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
}
float input_left = static_cast<float>(src_vec1[0]);
float input_right = static_cast<float>(src_vec1[1]);
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec1[0] =
static_cast<T>(tmp1);
out_vec1[1] =
static_cast<T>(tmp2);
} else {
out_vec1[0] = src_vec1[0];
out_vec1[1] = src_vec1[1];
}
// rope
input_left = static_cast<float>(src_vec2[0]);
input_right = static_cast<float>(src_vec2[1]);
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec2[0] = static_cast<T>(tmp1);
out_vec2[1] = static_cast<T>(tmp2);
} else {
out_vec2[0] = src_vec2[0];
out_vec2[1] = src_vec2[1];
}
if (k_norm_weight) {
if (head_idx < num_heads + kv_num_heads) {
LoadOutScaleT k_norm_vec1, k_norm_vec2;
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
// qk norm
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
}
}
}
// reduce max, 1 head per warp
T local_max = -INFINITY;
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(out_vec1[i]));
local_max = __hmax(local_max, __habs(out_vec2[i]));
}
#pragma unroll
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}
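// 448 is the largest finite magnitude of FP8 E4M3; quantizing with scale = 448 / local_max
// maps this head's values onto that range (the reciprocal is stored below as the dequant scale)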
scale = __hdiv(448, local_max);
if (lane_id == 0) {
if (head_idx < num_heads + kv_num_heads) {
cache_k_scale_now[0] = __hdiv(1, scale);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
}
}
#pragma unroll
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
}
if (head_idx < num_heads + kv_num_heads) {
const int start_block_16 =
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
const uint32_t tgt_cache_idx =
block_idx * kv_num_heads * block_size * HeadDim +
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
} else {
const uint32_t base_tgt_cache_idx =
block_idx * kv_num_heads * HeadDim * block_size +
kv_head_idx * HeadDim * block_size +
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
block_offset % 8 / 2 * 4 // per 4
+ block_offset % 16 / 8 * 2 // per 2
+ block_offset % 2; // per 1
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
value_cache[tgt_cache_idx1] = cache_vec[0];
value_cache[tgt_cache_idx2] = cache_vec[1];
value_cache[tgt_cache_idx3] = cache_vec[2];
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
__global__ void append_decode_cache_int8_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -1107,6 +639,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1144,18 +677,44 @@ __global__ void append_decode_cache_int8_rope_kernel(
if (head_idx < num_heads) {
// q
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
uint32_t emb_offset = write_seq_id * half_head_size;
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
apply_rope<T, VecSize, HeadDim, 32>(
qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
qkv_out_now,
lane_id);
LoadT src_vec;
LoadBiasT out_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
#pragma unroll
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
out_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
out_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
}
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
@@ -1330,6 +889,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1634,6 +1194,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1935,7 +1496,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2332,7 +1893,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2373,18 +1934,44 @@ __global__ void append_decode_cache_int4_rope_kernel(
if (head_idx < num_heads) {
// q
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
uint32_t emb_offset = write_seq_id * half_head_size;
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
apply_rope<T, VecSize, HeadDim, 32>(
qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
qkv_out_now,
lane_id);
LoadT src_vec;
LoadBiasT out_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
#pragma unroll
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
out_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
out_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
}
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
@@ -2604,7 +2191,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2935,7 +2522,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -3308,7 +2895,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]

View File

@@ -21,6 +21,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -58,6 +59,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -82,6 +84,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -94,7 +97,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int rotary_dim,
const int block_size,
const int bsz,
const cudaStream_t& stream,
@@ -118,6 +120,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -134,34 +137,13 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
if (rotary_dim < dim_head){
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
rotary_dim,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}else{
append_decode_cache_T_neox_rope_kernel<T, PackSize>
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -175,7 +157,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
elem_nums,
kv_num_heads,
rope_3d);
}
}
} else {
if (qkv_out_scales) {
@@ -186,6 +167,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -208,6 +190,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -231,6 +214,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -263,6 +247,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -288,6 +273,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -313,6 +299,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -338,6 +325,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -363,6 +351,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -397,6 +386,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -424,6 +414,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -451,6 +442,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -478,6 +470,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -504,6 +497,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -540,20 +534,11 @@ void DecoderWriteCacheWithRoPEKernel(
const float* cos_emb =
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
const float* sin_emb;
int rotary_dim = dim_head;
if (rotary_embs) {
sin_emb =
use_neox_rotary_style
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < dim_head){
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight is None, and cache_quant_type_str is 'none'."));
}
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
}
}
if (q_norm_weight && k_norm_weight) {
@@ -564,6 +549,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -586,39 +572,9 @@ void DecoderWriteCacheWithRoPEKernel(
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
"append_decode_cache_rope_qk_norm does not support cache KV quantization yet");
}
} else {
if (cache_quant_type_str == "none") {
@@ -628,6 +584,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -642,7 +599,6 @@ void DecoderWriteCacheWithRoPEKernel(
num_heads,
kv_num_heads,
dim_head,
rotary_dim,
block_size,
bsz,
stream,
@@ -660,6 +616,7 @@ void DecoderWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -692,6 +649,7 @@ void DecoderWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -725,6 +683,7 @@ void DecoderWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -750,36 +709,6 @@ void DecoderWriteCacheWithRoPEKernel(
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
nullptr,
nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -787,6 +716,7 @@ void DecoderWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -834,6 +764,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -863,6 +794,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -891,6 +823,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -919,6 +852,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,

View File

@@ -23,6 +23,7 @@ void DecoderWriteCacheWithRoPEKernel(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,

View File

@@ -449,8 +449,8 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
const int half_lastdim = last_dim / 2;
const int offset = (q_num_head + kv_num_head) * last_dim;
const int all_head_num = elem_cnt / last_dim;
for (int global_hi = global_warp_idx; global_hi < all_head_num; global_hi += all_warp_num) {
int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize;
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
@@ -900,74 +900,6 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
}
}
template <typename T, int VecSize = 1>
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales,
const T *qkv_biases,
T *qkv_out,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int head_dim,
const int rotary_dim,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
LoadT right_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int rotary_dim_half = rotary_dim / 2;
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
for (int64_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / rotary_dim_half;
const int h_bias = bias % rotary_dim_half;
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
const int base_idx_left =
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
h_bias;
const int base_idx_right = base_idx_left + rotary_dim_half;
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
const float input_left = static_cast<float>(left_vec[i]);
const float input_right = static_cast<float>(right_vec[i]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
}
}
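The GQANeoxVariableLengthPartialRotaryKernel removed above rotates only the first rotary_dim channels of each q/k head, pairing x[i] with x[i + rotary_dim/2] and passing the remaining channels through unchanged. A single-threaded C++ sketch of that rotation (launch and indexing bookkeeping omitted, names illustrative):

#include <cmath>
#include <cstdio>
#include <vector>

// Neox-style partial rotation of one head vector: only the first rotary_dim
// channels are touched, pairing x[i] with x[i + rotary_dim / 2].
void apply_partial_neox_rope(float* head,           // [head_dim]
                             const float* cos_row,  // [rotary_dim / 2] for this position
                             const float* sin_row,  // [rotary_dim / 2]
                             int head_dim, int rotary_dim) {
  const int half = rotary_dim / 2;
  for (int i = 0; i < half; ++i) {
    const float left = head[i];
    const float right = head[i + half];
    head[i] = left * cos_row[i] - right * sin_row[i];
    head[i + half] = right * cos_row[i] + left * sin_row[i];
  }
  // channels [rotary_dim, head_dim) are left untouched
}

int main() {
  const int head_dim = 8, rotary_dim = 4;
  std::vector<float> head = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> cos_row = {std::cos(0.5f), std::cos(1.0f)};
  std::vector<float> sin_row = {std::sin(0.5f), std::sin(1.0f)};
  apply_partial_neox_rope(head.data(), cos_row.data(), sin_row.data(), head_dim, rotary_dim);
  for (float v : head) printf("%.3f ", v);
  printf("\n");
  return 0;
}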
template <typename T, int VecSize = 1>
__global__ void cache_kernel(
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
@@ -1300,411 +1232,6 @@ __global__ void append_write_cache_kv_c8_qkv(
}
}
template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS,
bool is_need_kv_quant,
bool IsFP8 = true>
__global__ void append_write_cache_kv_c8_qkv_dynamic(
uint8_t *__restrict__ cache_k,
uint8_t *__restrict__ cache_v,
const T *__restrict__ qkv_input,
T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size]
T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size]
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t pad_len = BLOCK_SIZE;
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
const T cache_k_scale = cache_k_scales[kv_head_idx];
const T cache_v_scale = cache_v_scales[kv_head_idx];
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
const uint32_t batch_id = batch_ids[btid];
const uint32_t tile_id = tile_ids[btid];
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
if (seq_len_this_time <= 0) {
return;
}
const int *block_table_now = nullptr;
block_table_now = block_tables + batch_id * max_blocks_per_seq;
const uint32_t num_rows_per_block =
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
const uint32_t start_len = seq_lens_decoder[batch_id];
const uint32_t bf_pad_len = start_len % pad_len;
const uint32_t start_len_pad = start_len - bf_pad_len;
const uint32_t end_len = start_len + seq_len_this_time;
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_scale_smem[BLOCK_SIZE];
if (tile_start >= start_len) {
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
// reset k
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
uint32_t tgt_idx =
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
tid % num_vecs_per_head_k * KV_VEC_SIZE;
for (int block_i = tid / num_vecs_per_head_k;
block_i < BLOCK_SIZE;
block_i += num_token_each_time_k) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&cache_k[tgt_idx + block_i * HEAD_DIM]);
}
// reset v
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
tgt_idx =
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
tid % num_vecs_per_head_v * KV_VEC_SIZE;
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
block_i += num_token_each_time_v) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
}
}
smem_t k_smem(k_smem_ori);
smem_t v_smem(v_smem_ori);
uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp
/*
0 | 1
2 | 3
*/
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
/*
0 | 2
1 | 3
*/
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
tid % 16, wid * num_frags_v * 2 + tid / 16);
// load kv gmem to smem
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
tile_id * num_rows_per_block +
wid * num_frags_z * 16 + tid / 8;
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
(num_heads + kv_head_idx) * kv_h_stride +
tid % 8 * num_elems_per_128b<T>();
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
tid % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y / 4;
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
if (chunk_start >= start_len && chunk_start < end_len) {
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
}
kv_smem_offset_w =
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
k_read_idx += 8 * num_elems_per_128b<T>();
v_read_idx += 8 * num_elems_per_128b<T>();
}
kv_smem_offset_w =
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
2 * num_frags_y;
chunk_start += 4;
k_read_idx +=
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
v_read_idx +=
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
}
}
commit_group();
wait_group<0>();
__syncthreads();
// reduce scale
// 16 rows per warp
uint32_t kv_reduce_frag[4];
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);
T k_local_max_value[num_frags_z * 2];
T v_local_max_value[num_frags_z * 2];
#pragma unroll
for (int i = 0; i < num_frags_z * 2; i++) {
k_local_max_value[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < num_frags_z * 2; i++) {
v_local_max_value[i] = -INFINITY;
}
const int num_kv_heads = gridDim.z;
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
T *cache_k_scale_now = cache_k_scales + scale_offset;
T *cache_v_scale_now = cache_v_scales + scale_offset;
// k scale
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// reduce per thread, 4 threads each row
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
#pragma unroll
for (int i = 0; i < 4; i++) {
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
}
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
// reduce per row
for (int i = 0; i < 2; i++) {
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
// used for quant
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
}
// store
if (tid % 4 == 0) {
const int offset_now = wid * num_frags_z * 16 + tid / 4;
// used for dequant
if (tile_start + offset_now >= start_len) {
if (tile_start + offset_now < end_len) {
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
} else {
cache_k_scale_now[offset_now] = 0;
}
}
if (tile_start + offset_now + 8 >= start_len) {
if (tile_start + offset_now + 8 < end_len) {
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
} else {
cache_k_scale_now[offset_now + 8] = 0;
}
}
}
__syncthreads();
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
}
// v scale
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// reduce per thread, 4 threads each row
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
#pragma unroll
for (int i = 0; i < 4; i++) {
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
}
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
// reduce per row
for (int i = 0; i < 2; i++) {
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
}
// store
if (tid % 4 == 0) {
const int offset_now = wid * num_frags_z * 16 + tid / 4;
// used for dequant
if (tile_start + offset_now >= start_len) {
if (tile_start + offset_now < end_len) {
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
} else {
cache_v_scale_now[offset_now] = 0;
v_scale_smem[offset_now] = 0;
}
}
if (tile_start + offset_now + 8 >= start_len) {
if (tile_start + offset_now + 8 < end_len) {
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
} else {
cache_v_scale_now[offset_now + 8] = 0;
v_scale_smem[offset_now + 8] = 0;
}
}
}
__syncthreads();
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
}
__syncthreads();
// mask, quant, store
using LoadKVT = AlignedVector<uint8_t, 4>;
LoadKVT cache_vec1;
LoadKVT cache_vec2;
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
uint32_t kv_frag[4];
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t write_b_stride = HEAD_DIM;
const uint32_t write_d_stride = BLOCK_SIZE;
uint32_t k_write_idx = block_id * write_n_stride +
kv_head_idx * write_h_stride +
(wid * num_frags_z * 16 + tid / 4) * write_b_stride +
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t k_write_idx_now = k_write_idx_now_z +
fy % 2 * 8 * write_b_stride +
fy / 2 * 32; // + fy % 2 * 16;
// load
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
// quant
T *k_frag_T = reinterpret_cast<T *>(kv_frag);
if (bf_pad_len != 0) {
Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
}
#pragma unroll
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
uint8_t uint_quant_value;
if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
chunk_start_k + (v_id / 4) * 8 < end_len) {
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
} else {
uint_quant_value = 0;
}
if (bf_pad_len != 0) {
if (v_id < 4) {
cache_vec1[v_id] |= uint_quant_value;
} else {
cache_vec2[v_id % 4] |= uint_quant_value;
}
} else {
if (v_id < 4) {
cache_vec1[v_id] = uint_quant_value;
} else {
cache_vec2[v_id - 4] = uint_quant_value;
}
}
}
// store
Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
k_smem_offset_r =
k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
2 * num_frags_y;
chunk_start_k += 16;
}
uint32_t chunk_start_v = tile_start + tid % 4 * 2;
uint32_t v_write_idx = block_id * write_n_stride +
kv_head_idx * write_h_stride +
(wid * num_frags_v * 16 + tid / 4) * write_d_stride +
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
T v_scales[num_frags_z_v * 4];
for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
const int offset = v_i * 16;
const int t_offset = tid % 4 * 2;
v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
}
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
uint32_t v_write_idx_now = v_write_idx_now_v +
fz % 2 * 8 * write_d_stride +
fz / 2 * 32; // + fz % 2 * 16;
// load
v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
// quant
T *v_frag_T = reinterpret_cast<T *>(kv_frag);
if (bf_pad_len != 0) {
Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
}
#pragma unroll
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
uint8_t uint_quant_value;
if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
// store now
} else {
uint_quant_value = 0;
}
if (bf_pad_len != 0) {
if (v_id < 4) {
cache_vec1[v_id] |= uint_quant_value;
} else {
cache_vec2[v_id % 4] |= uint_quant_value;
}
} else {
if (v_id < 4) {
cache_vec1[v_id] = uint_quant_value;
} else {
cache_vec2[v_id % 4] = uint_quant_value;
}
}
}
// store
Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
chunk_start_v += 16;
v_smem_offset_r =
k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
}
v_smem_offset_r = k_smem.advance_offset_by_column<2>(
v_smem_offset_r, wid * num_frags_v + fy) -
16 * num_frags_z_v * num_vecs_per_head;
chunk_start_v -= 16 * num_frags_z_v;
}
}
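The append_write_cache_kv_c8_qkv_dynamic kernel removed above derives one quantization scale per (block, kv head, token slot): it reduces the absolute maximum over the row, uses 448/absmax as the quantization scale (448 is the largest finite e4m3 value) and writes the reciprocal into the scale cache for later dequantization. A host-side C++ sketch of that per-row scale computation under those assumptions:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct BlockWiseFp8Row {
  float quant_scale;    // multiply a value by this before converting to the 8-bit format
  float dequant_scale;  // stored next to the cache block, applied on the way out
};

BlockWiseFp8Row compute_row_scales(const float* row, int head_dim) {
  float absmax = 0.0f;
  for (int i = 0; i < head_dim; ++i) absmax = std::max(absmax, std::fabs(row[i]));
  BlockWiseFp8Row s;
  s.quant_scale = absmax > 0.0f ? 448.0f / absmax : 0.0f;
  s.dequant_scale = s.quant_scale > 0.0f ? 1.0f / s.quant_scale : 0.0f;
  return s;
}

int main() {
  std::vector<float> row = {0.25f, -3.5f, 1.75f, 0.5f};
  BlockWiseFp8Row s = compute_row_scales(row.data(), static_cast<int>(row.size()));
  printf("quant=%.4f dequant=%.4f\n", s.quant_scale, s.dequant_scale);
  // The removed kernel then feeds each value together with quant_scale to its
  // QuantToC8 helper to produce the byte written into the uint8 cache block.
  return 0;
}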
// Write Cache KV in Append
template <typename T,
uint32_t num_frags_y,
@@ -2228,7 +1755,6 @@ void gqa_rotary_qk_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const int rotary_dim,
const cudaStream_t &stream,
bool use_neox_style = false,
bool rope_3d = false) {
@@ -2309,38 +1835,7 @@ void gqa_rotary_qk_variable(
dim_head,
rope_3d);
} else {
if (rotary_dim < dim_head){
PD_CHECK((rotary_dim / 2) % PackSize == 0);
elem_nums =
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
if (use_neox_style) {
elem_nums /= 2;
}
const int pack_num_new = elem_nums / PackSize;
GetNumBlocks<128>(pack_num_new, &grid_size);
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
rotary_emb + input_output_len * rotary_dim / 2,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rotary_dim,
rope_3d);
}else{
GQANeoxVariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
@@ -2358,7 +1853,6 @@ void gqa_rotary_qk_variable(
seq_len,
dim_head,
rope_3d);
}
}
}
}
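In the partial-rotary branch removed above, the launch element count covers only the rotated channels: q and k heads (plus v when qkv_out_scales is present, since that kernel also dequantizes) times rotary_dim, halved for the neox kernel because each thread handles a (left, right) pair. A small sketch of that sizing with illustrative names:

#include <cstdio>

int rotary_launch_elems(int token_num, int num_heads, int kv_num_heads,
                        int rotary_dim, bool has_out_scales, bool use_neox) {
  int elems = has_out_scales
                  ? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
                  : token_num * (num_heads + kv_num_heads) * rotary_dim;
  if (use_neox) elems /= 2;  // one thread per (left, right) pair
  return elems;
}

int main() {
  // e.g. 8 tokens, 32 q heads, 4 kv heads, rotary_dim 64, neox style, no dequant
  printf("%d\n", rotary_launch_elems(8, 32, 4, 64, false, true));
  return 0;
}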
@@ -2512,11 +2006,10 @@ void CascadeAppendWriteCacheKVC8QKV(
int num_blocks_x_cpu,
int max_seq_len,
bool is_scale_channel_wise,
const std::string& cache_quant_type,
const bool is_fp8,
cudaStream_t &stream,
paddle::Tensor *cache_k_out,
paddle::Tensor *cache_v_out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto num_tokens = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
@@ -2534,77 +2027,49 @@ void CascadeAppendWriteCacheKVC8QKV(
dim3 blocks(32, num_warps);
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
if (cache_quant_type != "block_wise_fp8") {
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, false>;
if (cache_quant_type == "cache_fp8") {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, true>;
}
if (is_scale_channel_wise) {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
false>;
}
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
qkv.data<T>(),
cache_k_scale.data<T>(),
cache_v_scale.data<T>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads);
} else {
auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, true>;
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads);
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, false>;
if (is_fp8) {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, true>;
}
if (is_scale_channel_wise) {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
false>;
}
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
qkv.data<T>(),
cache_k_scale.data<T>(),
cache_v_scale.data<T>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads);
}
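Both versions of CascadeAppendWriteCacheKVC8QKV shown above pick a template instantiation of the write-cache kernel through a function pointer, based on runtime flags, and raise its dynamic shared-memory limit before the launch. A plain C++ stand-in for the flag-based variant at the end of the hunk (no real kernels, names illustrative):

#include <cstdio>

template <bool kNeedQuant, bool kIsFp8>
void write_cache_kv_c8(int block_size) {
  printf("write_cache_kv_c8<need_quant=%d, is_fp8=%d>(block_size=%d)\n",
         kNeedQuant, kIsFp8, block_size);
}

int main() {
  const bool is_fp8 = true;                 // e.g. cache_quant_type == "cache_fp8"
  const bool is_scale_channel_wise = false;
  auto kernel_fn = write_cache_kv_c8<true, false>;  // default: int8 path, per-head scale
  if (is_fp8) kernel_fn = write_cache_kv_c8<true, true>;
  if (is_scale_channel_wise) kernel_fn = write_cache_kv_c8<false, false>;
  // In the real code this is where cudaFuncSetAttribute(...) raises the dynamic
  // shared-memory limit before the <<<grids, blocks>>> launch.
  kernel_fn(/*block_size=*/64);
  return 0;
}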
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>

View File

@@ -55,19 +55,9 @@ void EncoderWriteCacheWithRopeKernel(
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;
bool is_scale_channel_wise = false;
int rotary_dim = head_dim;
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
is_scale_channel_wise = true;
}
if (rotary_embs){
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < head_dim){
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
}
}
}
if (q_norm_weight && k_norm_weight) {
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
@@ -135,7 +125,6 @@ void EncoderWriteCacheWithRopeKernel(
max_seq_len,
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
head_dim,
rotary_dim,
stream,
use_neox_style,
rope_3d);
@@ -178,7 +167,7 @@ void EncoderWriteCacheWithRopeKernel(
stream,
key_cache_out,
value_cache_out);
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
@@ -198,7 +187,7 @@ void EncoderWriteCacheWithRopeKernel(
num_blocks,
max_seq_len,
is_scale_channel_wise,
cache_quant_type_str,
cache_quant_type_str == "cache_fp8",
stream,
key_cache_out,
value_cache_out);

View File

@@ -11,11 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "utils.cuh"
template <int THREADBLOCK_SIZE>
__global__ void
@@ -117,93 +116,6 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
max_len_tensor.data<int>(), batch_size);
}
template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ num_blocks_x,
int *__restrict__ res_chunk_size,
const int bsz,
const int set_chunk_size,
const int block_size,
const int sm_cout) {
const uint32_t conf_id = threadIdx.x;
int gridx = 0;
if (set_chunk_size > 0 && conf_id == 0) {
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
gridx += loop_times;
}
*num_blocks_x = gridx;
*res_chunk_size = set_chunk_size;
} else if (conf_id < config_size) {
__shared__ int gridx_shared[config_size];
// chunk_size is a multiple of 64
const int chunk_size = block_size << conf_id;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
gridx += loop_times;
}
gridx_shared[conf_id] = gridx;
__syncthreads();
if (threadIdx.x == 0) {
uint32_t res_id = 0;
uint32_t max_last_wave_block = 0;
for (uint32_t i = 1; i < config_size; i++) {
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
}
*num_blocks_x = gridx_shared[res_id];
*res_chunk_size = block_size << res_id;
}
}
}
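When no fixed chunk size is configured, the search_chunk_size_for_mla kernel removed above sweeps candidate chunk sizes (block_size << i), counts how many decode chunks each candidate produces across the batch, and keeps the one whose last wave fills the most SMs (largest grid % sm_count). A host-side sketch of that heuristic with illustrative names:

#include <cstdio>
#include <vector>

struct ChunkChoice { int chunk_size; int num_blocks; };

ChunkChoice search_chunk_size(const std::vector<int>& seq_lens_q,
                              const std::vector<int>& seq_lens_encoder,
                              const std::vector<int>& seq_lens_decoder,
                              int block_size, int sm_count, int config_size) {
  ChunkChoice best{block_size, 0};
  int best_last_wave = -1;
  for (int conf = 0; conf < config_size; ++conf) {
    const int chunk = block_size << conf;   // chunk size is a multiple of block_size
    int grid = 0;
    for (size_t b = 0; b < seq_lens_q.size(); ++b) {
      if (seq_lens_q[b] == 0 || seq_lens_encoder[b] > 0) continue;  // decode-only batches
      const int total = seq_lens_decoder[b] + seq_lens_q[b];
      grid += (total + chunk - 1) / chunk;  // ceil_div
    }
    const int last_wave = grid % sm_count;
    if (last_wave >= best_last_wave) {      // prefer the config that best fills the last wave
      best_last_wave = last_wave;
      best = {chunk, grid};
    }
  }
  return best;
}

int main() {
  ChunkChoice c = search_chunk_size({1, 1, 0}, {0, 0, 128}, {4096, 900, 0},
                                    /*block_size=*/64, /*sm_count=*/80, /*config_size=*/12);
  printf("chunk_size=%d num_blocks=%d\n", c.chunk_size, c.num_blocks);
  return 0;
}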
__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
const int bsz,
const int chunk_size) {
if (threadIdx.x == 0) {
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
if (seq_len_encoder > 0) {
loop_times = 0;
}
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
}
}
}
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
@@ -279,23 +191,14 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
}
}
void GetBlockShapeAndSplitKVBlock(
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
@@ -320,120 +223,31 @@ void GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
// decoder
if (max_dec_len_this_time > 0) {
const bool mla_use_tensorcore = GetMlaUseTensorcore();
if (mla_use_tensorcore && group_size <= 64) {
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
int device;
cudaGetDevice(&device);
int sm_cout;
cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
constexpr int config_size =
12; // search space for chunk size:[64, 128, 256, ... 131072]
search_chunk_size_for_mla<config_size>
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_num_blocks_device.data<int>(),
decoder_chunk_size_device.data<int>(),
bsz,
set_chunk_size,
block_size,
sm_cout);
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
auto decoder_chunk_size_cpu =
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
stream));
split_block_for_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
bsz,
chunk_size);
} else {
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_device.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}
// encoder
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
@@ -444,12 +258,16 @@ void GetBlockShapeAndSplitKVBlock(
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -457,9 +275,54 @@ void GetBlockShapeAndSplitKVBlock(
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
}
if (max_just_dec_len_this_time > 0) {
// Clear buffer
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}
return {
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu, /*cpu*/
};
}
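On both sides of this diff, split_q_block turns per-batch token counts into flat (batch_id, tile_id) work lists for the attention kernels; the buffer sizes above imply each batch contributes up to div_up(len * group_size, block_shape_q) tiles. A host-side sketch of that tiling under that assumption; split_q_tiles and its signature are illustrative, not the real op:

#include <cstdio>
#include <vector>

int split_q_tiles(const std::vector<int>& seq_lens,          // tokens per batch this step
                  const std::vector<int>& seq_lens_encoder,  // > 0 means prefill batch
                  int block_shape_q, int group_size,
                  std::vector<int>* batch_ids, std::vector<int>* tile_ids) {
  int num_tiles = 0;
  for (size_t bid = 0; bid < seq_lens.size(); ++bid) {
    int len = seq_lens[bid];
    if (!seq_lens_encoder.empty() && seq_lens_encoder[bid] > 0) len = 0;  // handled by encoder path
    const int tiles = (len * group_size + block_shape_q - 1) / block_shape_q;
    for (int t = 0; t < tiles; ++t) {
      batch_ids->push_back(static_cast<int>(bid));
      tile_ids->push_back(t);
    }
    num_tiles += tiles;
  }
  return num_tiles;  // what decoder_num_blocks / encoder_num_blocks report
}

int main() {
  std::vector<int> batch_ids, tile_ids;
  const int n = split_q_tiles({2, 1, 5}, {0, 0, 0}, /*block_shape_q=*/16,
                              /*group_size=*/8, &batch_ids, &tile_ids);
  printf("tiles=%d first=(%d,%d)\n", n, batch_ids[0], tile_ids[0]);
  return 0;
}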
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -469,20 +332,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_cpu",
"decoder_num_blocks_device",
"decoder_chunk_size_device",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
"decoder_num_blocks_x_cpu",
"max_len_tensor_cpu"
})
.Outputs({
paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks_x_cpu"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks_x_cpu"),
"max_len_kv_cpu"
})
.Attrs({
"encoder_block_shape_q: int",

View File

@@ -217,7 +217,7 @@ __global__ void append_cache_kv_c16(
// load k_smem 64 rows 128 cols
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -235,7 +235,7 @@ __global__ void append_cache_kv_c16(
// deal k_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
uint32_t col_idx = fy * 16 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
// layout
@@ -278,7 +278,7 @@ __global__ void append_cache_kv_c16(
// load v_smem 64 rows 128 cols
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -296,7 +296,7 @@ __global__ void append_cache_kv_c16(
// deal v_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
uint32_t col_idx = fy * 16 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
// layout
@@ -400,7 +400,7 @@ __global__ void append_cache_kv_c8(
// load v_smem 64 rows, 128 cols
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -418,7 +418,7 @@ __global__ void append_cache_kv_c8(
// deal k_smem 64 rows, 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter
uint32_t col_idx = fy * 32 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
// layout
@@ -466,7 +466,7 @@ __global__ void append_cache_kv_c8(
tid % 4 * num_elems_per_128b<CacheT>();
// load v_smem 128 rows 64 cols
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -485,7 +485,7 @@ __global__ void append_cache_kv_c8(
// deal v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
// layout
@@ -614,7 +614,7 @@ __global__ void append_cache_kv_c4(
// load k_smem 64 rows 128 cols
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -632,7 +632,7 @@ __global__ void append_cache_kv_c4(
// deal k_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
uint32_t col_idx = fy * 64 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
@@ -685,7 +685,7 @@ __global__ void append_cache_kv_c4(
tid % 2 * num_elems_per_128b<CacheT>();
// load v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -704,7 +704,7 @@ __global__ void append_cache_kv_c4(
// deal v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
// layout
@@ -1000,7 +1000,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
meta_data,
*const_cast<paddle::Tensor*>(&key_cache),
@@ -1018,7 +1018,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
kv_num_blocks_data,
max_seq_len,
false, // is_scale_channel_wise
cache_quant_type,
cache_quant_type == "cache_fp8", // is_fp8
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));

View File

@@ -18,168 +18,6 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1, typename InT = T>
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const float*
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int output_inner_dim,
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadInT src_vec;
LoadFloat scale_vec;
LoadT bias_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec;
LoadFloat k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
const int half_head_size = head_size / 2;
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
const int token_id = linear_index / hidden_size;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
const int write_q_idx =
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
if (qkv_biases) {
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
}
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (qkv_out_scales) {
input_left *= scale_vec[2 * i];
input_right *= scale_vec[2 * i + 1];
}
if (qkv_biases) {
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
}
if (hi < num_heads + gqa_group_size) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
}
}
if (hi < (num_heads + gqa_group_size)) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
} else {
// write k/v
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
kv_head_idx * block_size * head_size +
block_offset * head_size + h_bias);
// write
if (hi < num_heads + gqa_group_size) {
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
}
}
}
}
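The append_speculate_cache_T_rope_qk_norm_kernel removed above fuses rotary embedding with a per-head RMS norm for the q/k heads (v is passed through): it rotates adjacent pairs, accumulates the sum of squares of the rotated values, normalizes by rsqrt(mean + eps), and scales by the q/k norm weight. A single-threaded sketch of that math (warp reduction and cache writes omitted, names illustrative):

#include <cmath>
#include <cstdio>
#include <vector>

void rope_then_rmsnorm(float* head, const float* cos_row, const float* sin_row,
                       const float* norm_weight, int head_dim, float eps) {
  float sum_sq = 0.0f;
  for (int i = 0; i < head_dim; i += 2) {  // rotate adjacent pairs (non-neox layout)
    const float left = head[i], right = head[i + 1];
    const float c = cos_row[i / 2], s = sin_row[i / 2];
    head[i] = left * c - right * s;
    head[i + 1] = right * c + left * s;
    sum_sq += head[i] * head[i] + head[i + 1] * head[i + 1];
  }
  const float inv_rms = 1.0f / std::sqrt(sum_sq / head_dim + eps);  // rsqrt(mean + eps)
  for (int i = 0; i < head_dim; ++i) head[i] *= inv_rms * norm_weight[i];
}

int main() {
  const int head_dim = 4;
  std::vector<float> head = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> cosr = {1.0f, 0.0f}, sinr = {0.0f, 1.0f};
  std::vector<float> w(head_dim, 1.0f);
  rope_then_rmsnorm(head.data(), cosr.data(), sinr.data(), w.data(), head_dim, 1e-6f);
  for (float v : head) printf("%.4f ", v);
  printf("\n");
  return 0;
}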
template <int VecSize = 4, int HeadDim = 128>
__global__ void append_clear_cache_int8_block(
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
@@ -355,8 +193,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -416,9 +253,8 @@ __global__ void append_speculate_cache_rope_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -490,8 +326,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -555,9 +390,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * head_size + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -642,8 +476,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -689,9 +522,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
}
@@ -751,11 +583,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
T scale;
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
} else {
scale = __ldg(&cache_v_scales[kv_head_idx]);
@@ -877,8 +708,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -927,9 +757,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
&left_out_scale_vec);
@@ -1024,11 +853,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
T scale;
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -1260,8 +1088,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1318,9 +1145,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -1409,11 +1235,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// &out_scale_vec2);
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -1606,8 +1431,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1757,11 +1581,10 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
&right_out_scale_vec2);
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
&left_scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
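
All of the hunks above follow the same pattern when selecting a cos/sin entry: a position-based index, plus an optional per-batch offset when rope_3d is set (the offset stride varies per kernel — max_seq_len times the head size, with an extra factor of 2 in the NeoX variants). A hedged sketch of just that index arithmetic; batch_stride is an illustrative parameter, not a name from the kernels:

#include <cstdint>

// Illustration only: choose the rotary-table offset for one element.
// For rope_3d the cos/sin table carries a per-batch leading dimension,
// so a batch-sized stride is added to the positional index.
int64_t rope_emb_index(int64_t pos_idx, int64_t batch_id,
                       int64_t batch_stride, bool rope_3d) {
  return rope_3d ? pos_idx + batch_id * batch_stride : pos_idx;
}
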

View File

@@ -15,78 +15,6 @@
#include "speculate_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
if (use_neox_style) {
PD_THROW(
"append_speculate_cache_rope_qk_norm not support neox rope yet");
} else {
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(qkv,
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
output_inner_dim,
dim_head,
block_size,
elem_nums,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
rope_3d);
}
}
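
The launcher above derives its grid from the flattened element count: each thread handles PackSize = HEAD_DIM / kWarpSize elements and GetNumBlocks bounds the grid. A small self-contained sketch of that arithmetic, with a plain ceil-div standing in for GetNumBlocks and example sizes that are assumptions rather than values from the kernel:

#include <cstdio>

int main() {
  // Example sizes (assumed for illustration only).
  const int token_num = 16, num_heads = 8, kv_num_heads = 1, dim_head = 128;

  const int elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;

  constexpr int kWarpSize = 32;
  constexpr int HEAD_DIM = 128;
  constexpr int PackSize = HEAD_DIM / kWarpSize;  // 4 elements per thread
  const int pack_num = elem_nums / PackSize;
  const int blocksize = 128;                      // threads per block
  const int grid_size = (pack_num + blocksize - 1) / blocksize;  // stand-in for GetNumBlocks<128>

  printf("elem_nums=%d pack_num=%d grid_size=%d\n", elem_nums, pack_num, grid_size);
  return 0;
}
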
// rope + write
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope(const QKV_TYPE* qkv,
@@ -111,8 +39,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
@@ -146,8 +73,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_rope_kernel<T, PackSize>
<<<grid_size, threads_per_block, 0, stream>>>(
@@ -170,8 +96,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
@@ -200,8 +125,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -243,8 +167,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -268,8 +191,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
@@ -300,8 +222,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -345,8 +266,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -372,8 +292,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
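
The int8 and int4 launchers above pass their quantization clamps explicitly (±127 for int8, 7 / -8 for signed int4). A tiny hedged sketch of a clamp-and-round step using those bounds; the kernels' actual scaling and rounding mode are not shown here, this only illustrates the role of max_bound / min_bound:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative quantization clamp: int8 uses [-127, 127], signed int4 uses [-8, 7].
int8_t quantize_clamp(float x, float scale, float min_bound, float max_bound) {
  float q = std::nearbyint(x * scale);              // scale, then round to nearest
  q = std::min(std::max(q, min_bound), max_bound);  // clamp to the cache dtype range
  return static_cast<int8_t>(q);
}
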
template <typename T, typename QKV_TYPE>
@@ -394,15 +313,11 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
typedef cascade_attn_type_traits<T> traits_;
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
typedef typename traits_::type DataType_;
@@ -427,185 +342,142 @@ void SpeculateWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope_qk_norm(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
rms_norm_eps,
rope_3d);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
}
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
}
@@ -628,15 +500,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -658,15 +526,11 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -687,15 +551,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
@@ -718,12 +578,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -35,12 +35,8 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -56,7 +56,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -104,6 +103,5 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -99,6 +98,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -441,15 +441,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
if (is_dynamic_cfp8) { \
constexpr bool IsDynamicC8 = true; \
__VA_ARGS__ \
} else { \
constexpr bool IsDynamicC8 = false; \
__VA_ARGS__ \
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \

View File

@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
@@ -105,7 +105,7 @@ void AppendAttentionWithOutput(
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
paddle::Tensor &fmha_out,
const paddle::optional<paddle::Tensor> &rotary_embs,
@@ -255,8 +255,7 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size);
const int estimate_total_token_nums);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
@@ -299,23 +298,14 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
const int layer_id);
void GetBlockShapeAndSplitKVBlock(
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
@@ -388,11 +378,9 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const paddle::optional<paddle::Tensor> &draft_tokens,
const paddle::optional<paddle::Tensor> &step_draft_tokens,
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
const int block_size,
const int max_draft_tokens);
const int block_size);
paddle::Tensor
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
@@ -416,8 +404,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
paddle::Tensor &text_input,
paddle::Tensor &image_input);
const paddle::Tensor &text_input,
const paddle::Tensor &image_input);
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor &image_input,
@@ -475,18 +463,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks_device,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -714,22 +707,6 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
const paddle::Tensor &step_seq_lens_this_time,
const paddle::Tensor &accept_num,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens);
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
@@ -773,7 +750,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
@@ -787,8 +763,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
const bool splitwise_prefill);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
@@ -1005,7 +980,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"),
py::arg("block_size"),
"per token per block quant and padding transpose scale");
"per token per block quant and padding tranpose scale");
m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"),
py::arg("recv_expert_count"), py::arg("block_size"),
@@ -1048,7 +1023,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
/**
* moe/fused_moe/moe_expert_ffn_wint2.cu
* moe/fused_moe/moe_ffn_wint2.cu
* moe_expert_ffn_wint2
*/
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
@@ -1253,8 +1228,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");

View File

@@ -89,11 +89,11 @@ public:
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of warp-level GEMM operations per load for B
/// Number of warp-level GEMM oeprations per load for B
static constexpr int kWarpGemmIterationsPerLoadForB =
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");

View File

@@ -117,7 +117,7 @@ class LeftGELUAndMul {
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &lhs,
FragmentAccumulator const &rhs) const {
// Convert source to internal compute numeric type
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_to_compute;

View File

@@ -117,7 +117,7 @@ class LeftSiLUAndMul {
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &lhs,
FragmentAccumulator const &rhs) const {
// Convert source to internal compute numeric type
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_to_compute;

View File

@@ -92,7 +92,7 @@ class DualMmaBase {
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator0::Policy::MmaShape::kK);

View File

@@ -219,7 +219,7 @@ class EpilogueVisitorPerRowPerColNf4 {
iterator_C_.clear_mask();
}
// NOTE(wangbojun) Currently, this kernel don't hanve implantention for
// adding elementwise beta, we keep this here for future usage beta_ =
// adding elementwise beta, we keep this here for future useage beta_ =
// (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr :
// params.elementwise.beta); if (beta_ == ElementAccumulator()) {
// iterator_C_.clear_mask();

View File

@@ -176,7 +176,7 @@ struct Nf4DefaultIteratorsTensorOp<cutlass::bfloat16_t,
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
public:

View File

@@ -64,7 +64,7 @@ template <
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
/// Operation perfomed by GEMM
typename Operator,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.

View File

@@ -133,7 +133,7 @@ public:
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
static constexpr int kNumKIterationsPerWarpBLoad =

View File

@@ -509,7 +509,7 @@ public:
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
// TODO(wangbojun) lds_converter can be remove for int8 B input
// TOOD(wangbojun) lds_converter can be remove for int8 B input
typename TransformBAfterLDS::result_type converted_frag_B =
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);

View File

@@ -96,7 +96,7 @@ public:
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
static constexpr int kNumKIterationsPerWarpBLoad =

View File

@@ -646,7 +646,7 @@ public:
// );
// }
}
// TODO(wangbojun) lds_converter can be remove for int8 B input
// TOOD(wangbojun) lds_converter can be remove for int8 B input
// int4
// typename TransformBAfterLDS::result_type converted_frag_B =
// lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);

View File

@@ -59,15 +59,6 @@ inline uint32_t get_cascade_attention_num_threads() {
inline bool get_mla_use_tensorcore() {
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
static const uint32_t mla_use_tensorcore =
mla_use_tensorcore_env == nullptr ? 0 : std::stoul(std::string(mla_use_tensorcore_env));
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
return mla_use_tensorcore != 0 ? true : false;
}
inline int get_mla_dec_chunk_size(int bsz) {
static const char* mla_dec_chunk_size_env =
std::getenv("FLAGS_mla_dec_chunk_size");
static const int mla_dec_chunk_size =
mla_dec_chunk_size_env == nullptr
? -1
: std::stoi(std::string(mla_dec_chunk_size_env));
return bsz > 1 ? mla_dec_chunk_size : 64;
}

View File

@@ -132,7 +132,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
}
PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",

View File

@@ -563,11 +563,3 @@ inline int GetSMVersion() {
return sm_version;
}
inline bool GetMlaUseTensorcore() {
static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore();
static const bool enable_mla_tensorcore = GetSMVersion() >= 90 ? true : false;
const bool mla_use_tensorcore =
flags_mla_use_tensorcore && enable_mla_tensorcore;
return mla_use_tensorcore;
}
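
GetMlaUseTensorcore above gates tensor-core MLA on both the FLAGS_mla_use_tensorcore environment flag and the SM version. A hedged sketch of that gating logic; one side of this diff defaults the flag to 1 and the other to 0, the sketch assumes a default of 1, and sm_version is a placeholder argument rather than the real GetSMVersion call:

#include <cstdlib>
#include <string>

// Illustrative restatement of the gate: the env flag can disable tensor cores,
// and they are only considered at all on SM90 or newer.
bool mla_use_tensorcore(int sm_version) {
  const char* env = std::getenv("FLAGS_mla_use_tensorcore");
  const bool flag = (env == nullptr) ? true : (std::stoul(std::string(env)) != 0);
  return flag && (sm_version >= 90);
}
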

View File

@@ -171,7 +171,7 @@ struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
public:

View File

@@ -30,12 +30,10 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
std::optional<paddle::Tensor> const& maybe_token_scales,
std::string maybe_schedule) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
std::optional<int64_t> maybe_group_size_opt;
std::optional<std::string> maybe_schedule_opt;
if (maybe_schedule == "") {
maybe_schedule_opt = std::nullopt;
} else {
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
}
return machete::mm_dispatch({.A = A,
.B = B,
@@ -65,8 +63,6 @@ std::vector<paddle::Tensor> MacheteMMKernel(
paddle::DataType maybe_out_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}

View File

@@ -51,8 +51,6 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}

View File

@@ -70,6 +70,7 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
@@ -77,8 +78,9 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const paddle::Tensor& decoder_chunk_size_device,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
@@ -95,12 +97,14 @@ void BatchMLAWithPagedKVCacheKernel(
const auto q_head_num = meta_data.q_num_heads;
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
const auto max_block_num = bsz * max_block_num_per_seq;
const uint32_t chunk_size = get_max_partition_size(bsz);
int q_head_dim = meta_data.head_dims;
int k_head_dim = meta_data.head_dims;
int v_head_dim = meta_data.head_dims_v;
// int num_chunks = max_dec_len / chunk_size;
int num_chunks = div_up(max_seq_len, 64);
int num_chunks = div_up(max_dec_len, chunk_size);
auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
@@ -123,14 +127,14 @@ void BatchMLAWithPagedKVCacheKernel(
params.d = reinterpret_cast<float*>(d_tmp->ptr());
params.block_tables = const_cast<int*>(block_tables.data<int>());
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
params.chunk_size_device =
const_cast<int*>(decoder_chunk_size_device.data<int>());
params.num_blocks_x_int = num_blocks_x;
params.q_stride_bsz = q_head_num * q_head_dim;
params.q_stride_head_num = q_head_dim;
params.kv_stride_block_num = block_size * k_head_dim;
@@ -147,6 +151,7 @@ void BatchMLAWithPagedKVCacheKernel(
params.block_size = block_size;
params.max_draft_token_num = draft_token_num;
params.sm_scale = softmax_scale;
params.chunk_size = chunk_size;
params.chunk_num = num_chunks;
if (q_head_dim == 576) {
@@ -171,6 +176,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
@@ -178,8 +184,9 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const paddle::Tensor& decoder_chunk_size_device,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
@@ -203,6 +210,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
@@ -210,8 +218,9 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const paddle::Tensor& decoder_chunk_size_device,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,

View File

@@ -47,6 +47,7 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
@@ -54,8 +55,9 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const paddle::Tensor& decoder_chunk_size_device,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,

View File

@@ -128,13 +128,12 @@ struct CollectiveMainloop {
DTypeMD const* d_ptr;
IdType const* kv_block_tables;
IdType const* seq_lens_this_time;
// IdType const* seq_lens_encoder;
IdType const* seq_lens_encoder;
IdType const* seq_lens_decoder;
IdType const* cumsum_q_seqlens;
IdType const* batch_ids;
IdType const* tile_ids_per_batch;
IdType const* num_blocks_x;
IdType const* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;
@@ -145,7 +144,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
// int chunk_size;
int chunk_size;
int chunk_num;
int max_draft_token_num;
};
@@ -161,13 +160,12 @@ struct CollectiveMainloop {
DTypeMD* d_ptr;
IdType* kv_block_tables;
IdType* seq_lens_this_time;
// IdType* seq_lens_encoder;
IdType* seq_lens_encoder;
IdType* seq_lens_decoder;
IdType* cumsum_q_seqlens;
IdType* batch_ids;
IdType* tile_ids_per_batch;
IdType* num_blocks_x;
IdType* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;
@@ -178,7 +176,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
// int chunk_size;
int chunk_size;
int chunk_num;
int max_draft_token_num;
TMA_KV tma_load_KV;
@@ -200,13 +198,12 @@ struct CollectiveMainloop {
const_cast<DTypeMD*>(args.d_ptr),
const_cast<IdType*>(args.kv_block_tables),
const_cast<IdType*>(args.seq_lens_this_time),
// const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_decoder),
const_cast<IdType*>(args.cumsum_q_seqlens),
const_cast<IdType*>(args.batch_ids),
const_cast<IdType*>(args.tile_ids_per_batch),
const_cast<IdType*>(args.num_blocks_x),
const_cast<IdType*>(args.chunk_size_device),
args.sm_scale,
args.bsz,
args.max_block_num,
@@ -217,7 +214,7 @@ struct CollectiveMainloop {
args.kv_stride_block_size,
args.o_stride_bsz,
args.o_stride_head_num,
// args.chunk_size,
args.chunk_size,
args.chunk_num,
args.max_draft_token_num,
tma_load_KV
@@ -284,9 +281,9 @@ struct CollectiveMainloop {
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
@@ -325,9 +322,9 @@ struct CollectiveMainloop {
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
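
Both load paths above turn a chunk index into a [start_tile_idx, end_tile_idx] range over KV tiles of BLOCK_SHAPE_KV; whether the chunk size comes from chunk_size or chunk_size_device[0] is exactly what this diff changes. A minimal sketch of that range computation with chunk_size as a plain argument:

#include <algorithm>

// Illustration of the per-chunk KV tile range used by load_kv / load_kv_tma.
struct TileRange { int start_tile_idx; int end_tile_idx; };

TileRange chunk_tile_range(int tile_idx, int chunk_size, int kv_len,
                           int block_shape_kv) {
  const int start_len = tile_idx * chunk_size;
  const int end_len = std::min(start_len + chunk_size, kv_len);
  const int start_tile = start_len / block_shape_kv;
  const int end_tile = (end_len + block_shape_kv - 1) / block_shape_kv - 1;  // ceil_div - 1
  return {start_tile, end_tile};
}
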

View File

@@ -57,7 +57,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
@@ -84,9 +84,9 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
@@ -263,7 +263,7 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
@@ -295,9 +295,9 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {

View File

@@ -62,12 +62,13 @@ struct Params {
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_encoder;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *batch_id_per_token;
@@ -75,7 +76,7 @@ struct Params {
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
alignas(16) IdType *chunk_size_device;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
@@ -95,7 +96,9 @@ struct Params {
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_size;
int chunk_num;
int num_blocks_x_int;
float sm_scale;
};
@@ -115,7 +118,7 @@ struct Params {
return cudaErrorNotSupported; \
}
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
@@ -134,7 +137,6 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
const int num_blocks_x = mainloop_params.num_blocks_x[0];
const int chunk_size = mainloop_params.chunk_size_device[0];
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
@@ -203,10 +205,58 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
const int block_id = blockIdx.x;
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
@@ -259,12 +309,76 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
} else {
const int block_id = blockIdx.x;
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
@@ -315,15 +429,15 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
start_token_idx,
tile_id,
seq_len_decoder_now,
chunk_size,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
}
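
The diff above flips the kernel's work distribution: with USE_FIXED_BLOCK a fixed, persistent grid strides over all (batch, tile) pairs in num_blocks_x, while the new default lets blockIdx.x address exactly one pair. A minimal, self-contained sketch of the two scheduling patterns, using a hypothetical process_tiles kernel rather than the MLA kernel itself (build with nvcc -std=c++17):

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for per-tile work; the real kernel loads Q/KV and runs the MLA mainloop.
template <bool USE_FIXED_BLOCK>
__global__ void process_tiles(const int* __restrict__ tile_ids,
                              float* __restrict__ out,
                              int num_tiles) {
  if constexpr (USE_FIXED_BLOCK) {
    // Persistent grid: a fixed number of blocks strides over all tiles.
    for (int i = blockIdx.x; i < num_tiles; i += gridDim.x) {
      if (threadIdx.x == 0) out[i] = 2.0f * tile_ids[i];
    }
  } else {
    // One tile per block: the grid is sized to exactly num_tiles at launch.
    const int i = blockIdx.x;
    if (threadIdx.x == 0) out[i] = 2.0f * tile_ids[i];
  }
}

int main() {
  const int num_tiles = 1000;
  int* tile_ids = nullptr;
  float* out = nullptr;
  cudaMallocManaged(&tile_ids, num_tiles * sizeof(int));
  cudaMallocManaged(&out, num_tiles * sizeof(float));
  for (int i = 0; i < num_tiles; ++i) tile_ids[i] = i;

  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, /*device=*/0);

  process_tiles<true><<<sm_count, 128>>>(tile_ids, out, num_tiles);   // fixed, graph-friendly grid
  process_tiles<false><<<num_tiles, 128>>>(tile_ids, out, num_tiles); // data-dependent grid
  cudaDeviceSynchronize();
  printf("out[%d] = %.1f\n", num_tiles - 1, out[num_tiles - 1]);

  cudaFree(tile_ids);
  cudaFree(out);
  return 0;
}

The persistent form keeps the launch shape independent of the batch, which is what makes it capturable into a CUDA graph; the per-tile form avoids the loop overhead when the grid can be sized exactly to the work.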
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
@@ -346,12 +460,12 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_encoder,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.chunk_size_device,
params.sm_scale,
params.bsz,
params.max_block_num,
@@ -362,6 +476,7 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_size,
params.chunk_num,
params.max_draft_token_num
});
@@ -385,9 +500,13 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
// NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
// by the graph.
dim3 grid_dims = {multiprocessor_count, 1, 1};
int gridx;
if constexpr(USE_FIXED_BLOCK) {
gridx = multiprocessor_count;
} else {
gridx = params.num_blocks_x_int;
}
dim3 grid_dims = {gridx, 1, 1};
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
@@ -398,38 +517,37 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE,
vec_size,
blocky,
KernelTraits::HEAD_DIM_VO>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_id_per_token,
params.chunk_size_device,
reinterpret_cast<NV_TYPE *>(params.O),
params.q_num_head,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num);
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.seq_lens_encoder,
params.cumsum_q_seqlens,
params.batch_id_per_token,
reinterpret_cast<NV_TYPE*>(params.O),
params.chunk_num,
params.q_num_head,
params.chunk_size,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num
);
}
return cudaSuccess;
}
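
Before launching, the dispatcher asks the runtime how many blocks of this kernel fit on one SM for the chosen CTA size and dynamic shared memory. A sketch of that query and of the fixed sm_count-sized launch, with a stand-in kernel; the 48 KB opt-in via cudaFuncSetAttribute is an assumption about how kernels with large dynamic shared memory are typically configured, not something shown in this hunk:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void attention_like_kernel(float* out) {
  extern __shared__ char smem[];  // dynamic shared memory, as in the MLA kernel
  if (blockIdx.x == 0 && threadIdx.x == 0) out[0] = 1.0f;
}

int main() {
  const int cta_size = 4 * 32;         // e.g. NUM_WARPS * 32
  const size_t smem_size = 64 * 1024;  // 64 KB of dynamic shared memory

  // Kernels that want more than 48 KB of dynamic shared memory must opt in.
  cudaFuncSetAttribute(attention_like_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       static_cast<int>(smem_size));

  int blocks_per_sm = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &blocks_per_sm, attention_like_kernel, cta_size, smem_size);

  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
  printf("blocks/SM: %d, SMs: %d -> %d resident blocks\n",
         blocks_per_sm, sm_count, blocks_per_sm * sm_count);

  float* out = nullptr;
  cudaMalloc(&out, sizeof(float));
  // Fixed grid: one block per SM, independent of the batch, so the launch
  // configuration stays stable across CUDA graph replays.
  attention_like_kernel<<<sm_count, cta_size, smem_size>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}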
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,

View File

@@ -249,16 +249,18 @@ struct prefill_softmax_state_t {
};
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [max_num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [max_num_chunks, bsz, max_draft_token, num_heads]
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
const int * __restrict__ seq_lens_this_time,
const int * __restrict__ seq_lens_decoder,
const int * __restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const int * __restrict__ batch_id_per_token,
const int * __restrict__ chunk_size_device,
T * __restrict__ out, // [token_num, num_heads, head_dim]
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim,
const int token_num,
const int bsz,
@@ -269,15 +271,13 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
// NOTE : (changwenbin) Batch_id_per_token is initialized to [:]=-1, Marking meaningless batch IDs.
if (bid == -1) continue;
const int seq_len_q = seq_lens_this_time[bid];
if (seq_len_q == 0) continue;
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
int seq_len_kv = seq_lens_decoder[bid];
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]);
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
// not need merge
continue;

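merge_multi_chunks_kernel folds the per-chunk partial results back together using the stored row maxima (multi_m) and softmax normalizers (multi_d). A host-side sketch of that log-sum-exp merge for one query row, assuming each chunk's partial output has already been divided by its own normalizer (the on-device data layout differs):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Each chunk produced: o = partial attention output (already divided by d),
// m = row max of the logits seen in that chunk, d = sum of exp(logit - m).
struct ChunkResult { std::vector<float> o; float m; float d; };

std::vector<float> merge_chunks(const std::vector<ChunkResult>& chunks, int head_dim) {
  float m_global = -INFINITY;
  for (const auto& c : chunks) m_global = std::max(m_global, c.m);

  std::vector<float> out(head_dim, 0.0f);
  float d_global = 0.0f;
  for (const auto& c : chunks) {
    const float scale = std::exp(c.m - m_global);  // rescale each chunk to the global max
    d_global += c.d * scale;
    for (int i = 0; i < head_dim; ++i) out[i] += c.o[i] * c.d * scale;
  }
  for (int i = 0; i < head_dim; ++i) out[i] /= d_global;
  return out;
}

int main() {
  std::vector<ChunkResult> chunks = {
      {{1.0f, 2.0f}, 0.5f, 3.0f},
      {{0.5f, 1.5f}, 1.0f, 2.0f},
  };
  const auto merged = merge_chunks(chunks, 2);
  printf("merged = [%f, %f]\n", merged[0], merged[1]);
  return 0;
}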
View File

@@ -383,7 +383,7 @@ __global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attenti
template<typename Kernel_traits, typename ParamType>
inline __device__ float calculate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType &params, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType &params, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads;
const int32_t bi = blockIdx.z;
@@ -524,7 +524,7 @@ __global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_a
const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize;
const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN;
float inv_global_exp_sum = calculate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
float inv_global_exp_sum = caluate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
using T_vec = Vec<cuteType, kNReducePacksize>;

View File

@@ -40,7 +40,7 @@ __global__ void write_encoder_cachekv_c16(
if (seq_len == 0) return;
const int remain_tokens = seq_len - block_idx;
const int ramian_tokens = seq_len - block_idx;
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize];
@@ -51,7 +51,7 @@ __global__ void write_encoder_cachekv_c16(
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < remain_tokens) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(k_input + base_load_idx + i * kv_head_num * kHeadDim);
}
}
@@ -62,7 +62,7 @@ __global__ void write_encoder_cachekv_c16(
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < remain_tokens) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(v_input + base_load_idx + i * kv_head_num * kHeadDim);
}
}

View File

@@ -50,14 +50,14 @@ __global__ void get_kv_from_cache_c16_kernel(
const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx];
const int remain_tokens = seq_len - base_token_idx;
const int ramian_tokens = seq_len - base_token_idx;
if (bidh < kv_head_num) {
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < remain_tokens) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_k + cache_offset + i * kHeadDim);
}
}
@@ -67,7 +67,7 @@ __global__ void get_kv_from_cache_c16_kernel(
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < remain_tokens) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_v + cache_offset + i * kHeadDim);
}
}
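
Both cache kernels above resolve a logical token position to a physical cache slot through block_tables. A small host-side sketch of that paged lookup; the [num_blocks, block_size, head_dim] cache layout and the -1 sentinel for unallocated blocks follow the conventions visible in these kernels, while the remaining names are made up for illustration:

#include <cstdio>
#include <vector>

// block_tables: [bsz, max_blocks_per_seq]; -1 marks a block that is not allocated.
int physical_offset(const std::vector<int>& block_tables, int max_blocks_per_seq,
                    int batch_id, int token_idx, int block_size, int head_dim) {
  const int logical_block  = token_idx / block_size;
  const int in_block       = token_idx % block_size;
  const int physical_block = block_tables[batch_id * max_blocks_per_seq + logical_block];
  if (physical_block < 0) return -1;  // block not allocated yet
  // cache layout assumed: [num_blocks, block_size, head_dim]
  return (physical_block * block_size + in_block) * head_dim;
}

int main() {
  const int max_blocks_per_seq = 4, block_size = 64, head_dim = 128;
  std::vector<int> block_tables = {7, 3, -1, -1,   // sequence 0
                                   2, -1, -1, -1}; // sequence 1
  printf("seq 0, token 70 -> offset %d\n",
         physical_offset(block_tables, max_blocks_per_seq, 0, 70, block_size, head_dim));
  return 0;
}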

View File

@@ -448,71 +448,137 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
auto place = input.place();
const int gridx = min(132 * 8, num_rows);
if (moe_quant_type == "w4a8") {
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, int8_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);)
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, int8_t, 8><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, int8_t, 16><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
}
} else if (moe_quant_type == "w4afp8") {
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, data_t_fp8, NUM_EXPERTS_PER_RANK, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);)
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
}
} else {
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, data_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);)
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, data_t, 16><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
}
}
}
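
The hunk above replaces the DISPATCH_NUM_EXPERTS_PER_RANK macro with explicit branches for 8 and 16 experts per rank. A macro of that shape typically bridges a runtime value to a compile-time template argument; the following is a hedged reconstruction of the idea (the real macro is not shown here and may cover more cases):

#include <cstdio>

// Bind the runtime expert count to a compile-time constant and run the body
// with NUM_EXPERTS_PER_RANK referring to it. The case list is an assumption.
#define DISPATCH_NUM_EXPERTS_PER_RANK(runtime_value, NUM_EXPERTS_PER_RANK, ...)  \
  switch (runtime_value) {                                                        \
    case 8:  { constexpr int NUM_EXPERTS_PER_RANK = 8;  __VA_ARGS__; break; }     \
    case 16: { constexpr int NUM_EXPERTS_PER_RANK = 16; __VA_ARGS__; break; }     \
    default: printf("unsupported num_experts_per_rank\n"); break;                 \
  }

template <int NUM_EXPERTS_PER_RANK>
void permute_stub(int num_rows) {
  printf("instantiated for %d experts/rank, %d rows\n", NUM_EXPERTS_PER_RANK, num_rows);
}

int main() {
  int num_experts_per_rank = 16;
  DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
      permute_stub<NUM_EXPERTS_PER_RANK>(1024));
  return 0;
}

Unrolling the macro into plain if/else, as the hunk does, trades a little duplication for code that is easier to read and to step through, at the cost of having to keep every branch in sync by hand.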

View File

@@ -872,14 +872,16 @@ void MoeFastHardamardWrapper(const T *x_data,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
OutT* out,
cudaStream_t &stream) {
bool FLAGS_hardamard_use_diagonal_block_matrix = true;
static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
constexpr int kThreads = 128;
if (FLAGS_hardamard_use_diagonal_block_matrix) {
const int VecSize = hadamard_block_size / kThreads;
const int VecSize = hardamard_moe_block_size / kThreads; // 128 / 128 = 1
const int logN = int(ceil(std::log2(kThreads * VecSize)));
constexpr int kNChunks = 1;
DISPATCH_SP_VS(VecSize, VEC_SIZE, {
@@ -989,7 +991,6 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
phi::dtype::float16 *out,
cudaStream_t &stream
);
@@ -1008,7 +1009,6 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
int8_t *out,
cudaStream_t &stream
);
@@ -1027,7 +1027,6 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
phi::dtype::bfloat16 *out,
cudaStream_t &stream
);
@@ -1046,7 +1045,6 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
int8_t *out,
cudaStream_t &stream
);
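
MoeFastHardamardWrapper now derives the block size from the FLAGS_hardamard_moe_block_size environment variable with a 512 default instead of taking hadamard_block_size as a parameter. A standalone sketch of that env-override pattern; kThreads = 128 mirrors the constant in the hunk:

#include <cstdio>
#include <cstdlib>
#include <string>

// Read an integer tuning knob from the environment, falling back to a default.
int env_int_or(const char* name, int default_value) {
  const char* raw = std::getenv(name);
  return raw != nullptr ? std::stoi(std::string(raw)) : default_value;
}

int main() {
  // e.g.  FLAGS_hardamard_moe_block_size=256 ./a.out
  const int block_size = env_int_or("FLAGS_hardamard_moe_block_size", 512);
  constexpr int kThreads = 128;
  printf("hadamard block size = %d, per-thread vec size = %d\n",
         block_size, block_size / kThreads);
  return 0;
}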

View File

@@ -32,6 +32,5 @@ void MoeFastHardamardWrapper(const T *x_data,
const int64_t dim,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
const int hadamard_block_size,
OutT* out,
cudaStream_t &stream);

View File

@@ -236,7 +236,7 @@ public:
num_experts, k, stream);
}
topk_gating_softmax_kernelLauncher<float, int>(
topk_gating_softmax_kernelLauncher<float, int>::run(
gating_output, nullptr, expert_scales_float, softmax_out_,
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
num_experts, k, group_moe, stream);
@@ -248,7 +248,7 @@ public:
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
false, stream);
initialize_moe_routing_kernelLauncher(
initialize_moe_routing_kernelLauncher<T>::run(
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
hidden_size, k, stream);
@@ -335,14 +335,14 @@ public:
num_experts, down_proj_quant_args, stream);
}
finalize_moe_routing_kernelLauncher(
finalize_moe_routing_kernelLauncher<T>::run(
fc2_result, output_, fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
routed_scaling_factor, stream);
} else {
finalize_moe_routing_kernelLauncher(
finalize_moe_routing_kernelLauncher<T>::run(
// fc2_result,
fc1_out, output_,
fc1_expert_biases, // fc2_expert_biases,

View File

@@ -1139,7 +1139,9 @@ void topk_gating_softmax_launcher_helper(const T* input,
}
template <typename T, typename IdxT = int>
void topk_gating_softmax_kernelLauncher(const T* input,
struct topk_gating_softmax_kernelLauncher{
static void run(const T* input,
const T* gating_correction_bias,
T* output,
T* softmax,
@@ -1219,6 +1221,7 @@ void topk_gating_softmax_kernelLauncher(const T* input,
}
}
}
};
// ========================== Permutation things
// =======================================
@@ -1313,7 +1316,9 @@ __global__ void initialize_moe_routing_kernel(
}
template <typename T, typename OutT = T>
void initialize_moe_routing_kernelLauncher(
struct initialize_moe_routing_kernelLauncher{
static void run(
const T* unpermuted_input,
OutT* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
@@ -1356,6 +1361,7 @@ void initialize_moe_routing_kernelLauncher(
num_rows * k);
}
}
};
// ============================== Infer GEMM sizes
// =================================
@@ -1466,7 +1472,8 @@ __global__ void finalize_moe_routing_kernel(
}
template <typename T>
void finalize_moe_routing_kernelLauncher(
struct finalize_moe_routing_kernelLauncher{
static void run(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
@@ -1498,4 +1505,5 @@ void finalize_moe_routing_kernelLauncher(
routed_scaling_factor,
num_rows);
}
};
} // namespace phi
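
topk_gating_softmax_kernelLauncher, initialize_moe_routing_kernelLauncher and finalize_moe_routing_kernelLauncher are converted from function templates into structs with a static run member, so call sites now spell the template arguments explicitly (e.g. initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(...)). One plausible motivation is that class templates, unlike function templates, can be partially specialized; a minimal sketch of the pattern with a hypothetical launcher:

#include <cstdio>

// Before: template <typename T, typename IdxT = int> void launcher(...);
// After: a struct with a static run(), which can be partially specialized.
template <typename T, typename IdxT = int>
struct topk_softmax_launcher {
  static void run(const T* input, IdxT* indices, int num_rows, int k) {
    printf("generic path: rows=%d k=%d\n", num_rows, k);
    (void)input; (void)indices;
  }
};

// Partial specialization for 64-bit indices, not expressible with a function template.
template <typename T>
struct topk_softmax_launcher<T, long long> {
  static void run(const T* input, long long* indices, int num_rows, int k) {
    printf("64-bit index path: rows=%d k=%d\n", num_rows, k);
    (void)input; (void)indices;
  }
};

int main() {
  float input[4] = {0.1f, 0.9f, 0.3f, 0.7f};
  int idx32[2];
  long long idx64[2];
  topk_softmax_launcher<float, int>::run(input, idx32, 1, 2);
  topk_softmax_launcher<float, long long>::run(input, idx64, 1, 2);
  return 0;
}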

View File

@@ -36,9 +36,6 @@ void MoeDispatchKernel(
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
using namespace phi;
if (num_rows == 0){
return;
}
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -83,7 +80,7 @@ void MoeDispatchKernel(
if (group_moe) {
paddle::Tensor softmax_max_prob_tensor =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
// (TODO: check fill success ?)
// (TODO: check fill sucess ?)
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
softmax_max_prob = softmax_max_prob_tensor.data<float>();
}
@@ -103,7 +100,7 @@ void MoeDispatchKernel(
softmax_out_ = nullptr;
}
topk_gating_softmax_kernelLauncher(
topk_gating_softmax_kernelLauncher<float, int>::run(
gating_output.data<float>(),
gating_correction_bias ? gating_correction_bias.get().data<float>()
: nullptr,
@@ -117,13 +114,13 @@ void MoeDispatchKernel(
if (w4a8_in_scale) {
if (permute_input->dtype() == paddle::DataType::INT8) {
initialize_moe_routing_kernelLauncher(
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
hidden_size, moe_topk, stream);
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
initialize_moe_routing_kernelLauncher(
initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
permuted_rows_, expert_idx_per_token->data<int32_t>(),
w4a8_in_scale->data<float>(),
@@ -131,7 +128,7 @@ void MoeDispatchKernel(
hidden_size, moe_topk, stream);
}
} else {
initialize_moe_routing_kernelLauncher(
initialize_moe_routing_kernelLauncher<data_t>::run(
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), nullptr,
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
@@ -188,15 +185,6 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto expert_idx_per_token =
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
if (token_rows == 0){
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
topk_weight,
topk_idx,
expert_idx_per_token};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(

View File

@@ -35,8 +35,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
const std::string& quant_method,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size) {
const int estimate_total_token_nums) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
@@ -292,7 +291,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
stream
);
@@ -342,7 +340,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
hadamard_block_size,
act_out_tensor.data<data_t>(),
stream
);
@@ -406,15 +403,13 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums, const int hadamard_block_size) {
const int estimate_total_token_nums) {
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
if(permute_input.numel() == 0){
return ffn_out;
}
switch (t_type) {
case paddle::DataType::BFLOAT16:
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
@@ -429,8 +424,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
quant_method,
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size);
estimate_total_token_nums);
break;
case paddle::DataType::FLOAT16:
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -445,8 +439,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
quant_method,
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size);
estimate_total_token_nums);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
@@ -465,8 +458,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size) {
const int estimate_total_token_nums) {
return {MoeExpertFFNFunc(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
@@ -478,8 +470,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
expert_idx_per_token,
quant_method,
used_in_ep_low_latency,
estimate_total_token_nums,
hadamard_block_size)};
estimate_total_token_nums)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
@@ -494,8 +485,7 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
const std::string& quant_method,
const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size) {
const int estimate_total_token_nums) {
return {permute_input_shape};
}
@@ -509,7 +499,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
const std::string &quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums, const int hadamard_block_size) {
const int estimate_total_token_nums) {
if (quant_method == "w4a8" || quant_method == "w4afp8") {
return {up_gate_proj_scale_dtype.get()};
} else {
@@ -565,8 +555,6 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
* - used_in_ep_low_latency: Whether running in low latency mode
* Affects activation function implementation
* - estimate_total_token_nums: estimate total token numbers
* - hadamard_block_size: hadamard block size for w4a8/w4afp8 quantization
*
* Note:
* - w4a8 mode requires additional workspace memory allocation
@@ -583,7 +571,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
paddle::Optional("down_proj_in_scale"),
paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
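
Dropping the hadamard_block_size attribute has to be done consistently across the op registration, MoeExpertFFN, MoeExpertFFNInferShape and MoeExpertFFNInferDtype, because custom-op attributes are passed positionally after the tensor inputs. A toy registration showing that coupling, using the same macros as this file (hypothetical op name and attributes):

#include "paddle/extension.h"

// Attributes arrive after the tensor inputs, in the order declared in .Attrs().
std::vector<paddle::Tensor> ToyForward(const paddle::Tensor& x,
                                       const std::string& quant_method,
                                       const int block_size) {
  // A real kernel would dispatch on quant_method / block_size and launch work here.
  auto out = paddle::empty_like(x, x.dtype());
  return {out};
}

std::vector<std::vector<int64_t>> ToyInferShape(const std::vector<int64_t>& x_shape,
                                                const std::string& quant_method,
                                                const int block_size) {
  return {x_shape};
}

std::vector<paddle::DataType> ToyInferDtype(const paddle::DataType& x_dtype,
                                            const std::string& quant_method,
                                            const int block_size) {
  return {x_dtype};
}

PD_BUILD_STATIC_OP(toy_block_op)
    .Inputs({"x"})
    .Outputs({"out"})
    .Attrs({"quant_method:std::string", "block_size:int"})
    .SetKernelFn(PD_KERNEL(ToyForward))
    .SetInferShapeFn(PD_INFER_SHAPE(ToyInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ToyInferDtype));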

View File

@@ -36,7 +36,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
typedef typename traits_::data_t data_t;
auto stream = ffn_out.stream();
finalize_moe_routing_kernelLauncher(
finalize_moe_routing_kernelLauncher<data_t>::run(
ffn_out.data<data_t>(), output->data<data_t>(),
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
@@ -59,10 +59,6 @@ paddle::Tensor MoeExpertReduceFunc(
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
if(num_rows == 0){
return output;
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(

View File

@@ -22,18 +22,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -59,12 +64,9 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
typedef PDTraits<D> traits_;
typedef typename traits_::data_t data_t;
// NOTE: (changwenbin) In cuda graph, it will be fixed in the capture stage
// int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];
// int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
//
const bool mla_use_tensorcore = get_mla_use_tensorcore();
auto sm_version = GetSMVersion();
@@ -94,6 +96,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
out_linear_smooths,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
batch_id_per_token,
block_tables,
@@ -101,8 +104,9 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
decoder_tile_ids_per_batch,
decoder_num_blocks,
cache_quant_type_str,
decoder_chunk_size_device,
decoder_num_blocks_data,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
@@ -141,18 +145,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -199,18 +208,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_chunk_size_device,
decoder_num_blocks_cpu,
max_enc_len_this_time,
max_dec_len_this_time,
max_len_kv,
attn_mask,
@@ -240,18 +254,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_chunk_size_device,
decoder_num_blocks_cpu,
max_enc_len_this_time,
max_dec_len_this_time,
max_len_kv,
attn_mask,
@@ -288,18 +307,23 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& query_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& decoder_chunk_size_device_shape,
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
const std::vector<int64_t>& max_enc_len_this_time_shape,
const std::vector<int64_t>& max_dec_len_this_time_shape,
const std::vector<int64_t>& max_len_kv_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
@@ -337,18 +361,23 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& query_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& decoder_chunk_size_device_dtype,
const paddle::DataType& decoder_num_blocks_cpu_dtype,
const paddle::DataType& max_enc_len_this_time_dtype,
const paddle::DataType& max_dec_len_this_time_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
@@ -386,18 +415,23 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention)
.Inputs({"query",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seqlens_q",
"batch_id_per_token",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"decoder_chunk_size_device",
"decoder_num_blocks_cpu",
"max_enc_len_this_time",
"max_dec_len_this_time",
"max_len_kv",
paddle::Optional("attn_mask"),

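Several of the new arguments (decoder_chunk_size_device, decoder_num_blocks) stay on the device and are read inside the kernels (chunk_size_device[0]) rather than being copied to the host and baked into launch parameters, which is what lets a captured launch keep working when the values change. A stripped-down sketch of reading a launch-time scalar from device memory, with a hypothetical kernel:

#include <cstdio>
#include <cuda_runtime.h>

// The chunk size lives in device memory; the kernel reads it at run time,
// so the same launch configuration works after the value is updated in place.
__global__ void use_device_scalar(const int* __restrict__ chunk_size_device, int* out) {
  if (blockIdx.x == 0 && threadIdx.x == 0) out[0] = chunk_size_device[0] * 2;
}

int main() {
  int *chunk_size = nullptr, *out = nullptr;
  cudaMalloc(&chunk_size, sizeof(int));
  cudaMalloc(&out, sizeof(int));

  int host_val = 64, host_out = 0;
  cudaMemcpy(chunk_size, &host_val, sizeof(int), cudaMemcpyHostToDevice);
  use_device_scalar<<<1, 32>>>(chunk_size, out);   // launch shape is fixed
  cudaMemcpy(&host_out, out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("out = %d\n", host_out);                  // 128

  host_val = 128;                                  // update the scalar in place...
  cudaMemcpy(chunk_size, &host_val, sizeof(int), cudaMemcpyHostToDevice);
  use_device_scalar<<<1, 32>>>(chunk_size, out);   // ...no change to launch parameters
  cudaMemcpy(&host_out, out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("out = %d\n", host_out);                  // 256

  cudaFree(chunk_size);
  cudaFree(out);
  return 0;
}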
View File

@@ -15,72 +15,31 @@
#include "helper.h"
__global__ void recover_decode_task(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
int thread_idx = threadIdx.x;
if (thread_idx < bsz) {
if(is_block_step[thread_idx] == true) {
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
// can be recovered for decoding
is_block_step[thread_idx] = false;
seq_lens_this_time[thread_idx]= 1;
stop_flags[thread_idx] = false;
seq_lens_encoder[thread_idx] = 0;
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
}
}
}
}
__global__ void recover_spec_decode_task(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
int64_t *draft_tokens,
const int64_t *step_draft_tokens,
const int *step_seq_lens_this_time,
const int bsz,
const int block_num_per_seq,
const int block_size,
const int draft_tokens_len,
const int num_extra_tokens) {
int thread_idx = threadIdx.x;
if (thread_idx < bsz) {
if(is_block_step[thread_idx] == true) {
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size;
max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq);
if (block_table_now[max_possible_block_idx] != -1) {
// can be recovered for decoding
int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len;
const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len;
is_block_step[thread_idx] = false;
seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx];
stop_flags[thread_idx] = false;
seq_lens_encoder[thread_idx] = 0;
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) {
draft_tokens_now[i] = step_draft_tokens_now[i];
// can be recovered for decoding
is_block_step[thread_idx] = false;
seq_lens_this_time[thread_idx]= 1;
stop_flags[thread_idx] = false;
seq_lens_encoder[thread_idx] = 0;
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
}
}
}
}
}
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
@@ -88,11 +47,7 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const paddle::optional<paddle::Tensor> &draft_tokens,
const paddle::optional<paddle::Tensor> &step_draft_tokens,
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
const int block_size,
const int max_draft_tokens) {
const int block_size) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
auto cu_stream = dev_ctx->stream();
@@ -101,38 +56,17 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
#endif
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
if (draft_tokens) {
const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1];
recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int64_t *>(draft_tokens.get_ptr()->data<int64_t>()),
step_draft_tokens.get_ptr()->data<int64_t>(),
step_seq_lens_this_time.get_ptr()->data<int>(),
bsz,
block_num_per_seq,
block_size,
draft_tokens_len,
max_draft_tokens * 2 + 1);
} else {
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
bsz,
block_num_per_seq,
block_size);
}
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
bsz,
block_num_per_seq,
block_size);
}
PD_BUILD_STATIC_OP(recover_decode_task)
@@ -142,11 +76,8 @@ PD_BUILD_STATIC_OP(recover_decode_task)
"seq_lens_decoder",
"step_seq_lens_decoder",
"block_tables",
"is_block_step",
paddle::Optional("draft_tokens"),
paddle::Optional("step_draft_tokens"),
paddle::Optional("step_seq_lens_this_time")})
.Attrs({"block_size: int", "max_draft_tokens: int"})
"is_block_step"})
.Attrs({"block_size: int"})
.Outputs({"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",

View File

@@ -75,7 +75,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 is preserve for no-output indication.
// 2 and -2 is perserve for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}

View File

@@ -45,7 +45,7 @@ void save_kernel(const paddle::Tensor& x,
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
// 2 and -2 is preserve for no-output indication.
// 2 and -2 is perserve for no-output indication.
throw std::runtime_error(
" INFERENCE_MSG_ID cannot be 2, please use other number.");
}

View File

@@ -34,7 +34,7 @@ __global__ void set_value_by_flag_and_id(const bool *stop_flags,
const int64_t *input_ids_now = input_ids + tid * length_input_ids;
const int seq_len_dec = seq_lens_decoder[tid];
const int seq_len_enc = seq_lens_encoder[tid];
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
if (step_idx[tid] >= 0) {
if (seq_len_enc > 0) { // encoder, get last token accord to seq_lens_encoder
pre_ids_all_now[step_idx[tid]] = input_ids_now[seq_len_enc - 1];

Some files were not shown because too many files have changed in this diff.