Compare commits

..

23 Commits

Author SHA1 Message Date
Zero Rains
bd30b08521 get org_vocab_size from args (#3981) 2025-09-09 15:08:47 +08:00
Divano
1aa16146ba Update requirements.txt (#3915) 2025-09-05 13:51:22 +08:00
ApplEOFDiscord
dac0a00d0f [BugFix] fix max streaming tokens invalid (#3774) (#3856)
* Update serving_chat.py

* Update serving_completion.py

Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
2025-09-03 17:50:29 +08:00
ltd0924
c5591c45df [BugFix] fix max streaming tokens invalid (#3774)
* Update serving_chat.py

* Update serving_completion.py
2025-09-02 21:00:29 +08:00
chen
121ac85d7d fix (#3640) 2025-08-27 14:23:38 +08:00
chen
d233e3c97c [Precision] Change lm_head layer running in float32 (#3596)
* support lm_head fp32 bf16 fp16

* delete print

* code check

* check

* check

* code check

* check

* check
2025-08-26 20:20:06 +08:00
chen
2136990144 [Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing (#3536)
* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing

* infer engine support temp_scaled_logprobs and top_p_normalized_logprobs

* code check

* code check

* fix tokenizer.decoder(-1), return 'Invalid Token'

* check seq len time shape

* logprob clip inf

* code check

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
2025-08-25 14:11:18 +08:00
kevin
b7890cbe8d fix uvicorn multi worker error (#3339) 2025-08-25 11:24:07 +08:00
chenjian
bc388b65c7 [Bug fix] Fix bug in logprob in release 2.0.4 (#3445)
* fix bug for scheduler v0

* Fix logprob in release/2.0.4
2025-08-16 21:13:10 +08:00
Jiang-Jia-Jun
71af0ca04a [BugFix] Fix default log level of paddleformers (#3378) 2025-08-15 18:30:00 +08:00
YuBaoku
d66660a0d1 [CI] fix run_ci error in release/2.0.4 (#3411) 2025-08-14 22:44:17 +08:00
xiaolei373
f0519aec67 feat(log):add_request_and_response_log (#3391)
* feat(log):add_request_and_response_log

* [ci] Retrigger

* [ci] Retrigger
2025-08-14 19:12:42 +08:00
gaoziyuan
1f5983290c fix mapping (#3321) 2025-08-12 16:17:59 +08:00
chenjian
c6a133d573 [Bug fix] Fix block num in scheduler v1 for release2.0.4 (#3314)
* fix bug for scheduler v0

* fix block num setting in scheduler v1

* fix block num setting in scheduler v1

* fix block num setting in scheduler v1

* fix block num setting in scheduler v1

* fix block num setting in scheduler v1
2025-08-11 23:55:45 +08:00
chenjian
4646aff25c fix bug for scheduler v0 (#3307) 2025-08-11 23:55:20 +08:00
chenjian
a84a98b107 fix scheduler bug due to async running (#3293) 2025-08-10 13:54:59 +08:00
chenjian
c208086f61 fix scheduler bug for bs=1 (#3288) 2025-08-09 12:22:12 +08:00
sg263
ce1d4944e7 merge develop trace FD_START (#3253) (#3260)
Co-authored-by: shige <shige@baidu.com>
2025-08-07 16:06:58 +08:00
chenjian
5439fb6336 [Cherry-pick] FIx bug for scheduler V1 (#3167)
* [BUG FIX] Fix bug when preempted request rescheduled (#3080)

* Fix bug when preempted request rescheduled

* Fix bug when preempted request rescheduled

* Fix bug when preempted request rescheduled

* Fix bug for offline inference in scheduler v1 (#3117)
2025-08-04 17:08:12 +08:00
gaoziyuan
a592d17615 support qwen3 name_mapping (#3180) 2025-08-04 16:37:34 +08:00
李泳桦
eca8fc7ca6 [feat] extra parameters are all passed directly via http payload now, or in extra_body if using openai client (#3077)
* [feat] extra parameters are all passed directly via http payload now, or in extra_body if using openai client

* [fix] delete ci test case for enable_thinking

* [fix] add reasoning_parser when server starts

* [doc] update docs related to metadata

* [fix] fix ci consistency test error with reasoning parser

* [fix] cancel enable_thinking default value
2025-07-30 19:25:39 +08:00
李泳桦
0463797fc2 [feat] add disable_chat_template in chat api as a substitute for previous raw_request (#3023)
* [feat] add disable_chat_template in chat api as a substitute for previous raw_request

* [fix] pre-commit code check
2025-07-25 20:57:06 +08:00
Jiang-Jia-Jun
0ab8645fc4 Update setup.py 2025-07-25 10:27:51 +08:00
824 changed files with 14747 additions and 91986 deletions

View File

@@ -2,9 +2,7 @@ name: Codestyle-Check
on:
pull_request:
branches:
- develop
- 'release/*'
branches: ["develop"]
jobs:
pre-commit:
@@ -13,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
env:
PR_ID: ${{ github.event.pull_request.number }}
BRANCH: ${{ github.event.pull_request.base.ref }}
BRANCH: develop
steps:
- name: Cleanup

View File

@@ -1,186 +0,0 @@
name: Accuracy Test
description: "Run Accuracy Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
accuracy_tests:
runs-on: [self-hosted, GPU-h20-1Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Base Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
chmod +x ./llm-deploy-linux-amd64
./llm-deploy-linux-amd64 -python python3.10 \
-model_name ERNIE-4.5-0.3B-Paddle \
-model_path /MODELDATA \
--skip install
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
pushd tests/ce/deploy
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
popd
pushd tests/ce/accuracy_cases
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
export TEMPLATE=TOKEN_LOGPROB
export MODEL_SIZE=0.3B
TEST_EXIT_CODE=0
python gsm8k.py || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -1,229 +0,0 @@
name: Base Test
description: "Run Base Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
base_tests:
runs-on: [self-hosted, GPU-h20-1Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Base Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
chmod +x ./llm-deploy-linux-amd64
./llm-deploy-linux-amd64 -python python3.10 \
-model_name ERNIE-4.5-0.3B-Paddle \
-model_path /MODELDATA \
--skip install
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
pushd tests/ce/deploy
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
check_service() {
local timeout=${1:-90}
local url="http://localhost:${FLASK_PORT}/wait_for_infer?timeout=${timeout}"
local resp
resp=$(curl -s -X POST "$url")
if echo "$resp" | grep -q "服务启动超时"; then
exit 8
fi
}
check_service 90
popd
pushd tests/ce/server
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
export TEMPLATE=TOKEN_LOGPROB
TEST_EXIT_CODE=0
python -m pytest -sv test_base_chat.py test_compare_top_logprobs.py test_logprobs.py test_params_boundary.py test_seed_usage.py test_stream.py test_evil_cases.py test_completions.py test_return_token_ids.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--early-stop-config\": \"{\\\"enable_early_stop\\\":true, \\\"window_size\\\":6, \\\"threshold\\\":0.93}\"}"
check_service 90
python -m pytest -sv test_repetition_early_stop.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }"
check_service 90
python -m pytest -sv test_max_concurrency.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }"
check_service 90
python -m pytest -sv test_max_waiting_time.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_mtp.yaml\", \"--enable-logprob\": \"False\"}"
check_service 180
export TEMPLATE=TOKEN_NORMAL
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_sot.yaml\", \"--enable-logprob\": \"False\"}"
check_service 360
export TEMPLATE=TOKEN_NORMAL
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -22,22 +22,12 @@ on:
description: "Enable nightly build mode (e.g. add date suffix to version)"
required: false
type: string
default: "OFF"
default: "ON"
FD_VERSION:
description: "FastDeploy Package Version"
required: false
type: string
default: ""
PADDLEVERSION:
description: "Paddle Version Build Use"
required: false
type: string
default: ""
PADDLE_WHL_URL:
description: "Paddle Wheel Package URL"
required: false
type: string
default: ""
UPLOAD:
description: "Upload Package"
required: false
@@ -54,8 +44,7 @@ on:
value: ${{ jobs.fd-build.outputs.wheel_path }}
jobs:
fd-build:
runs-on: [self-hosted, GPU-Build]
timeout-minutes: 240
runs-on: [self-hosted, GPU-h1z1-4Cards]
outputs:
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
steps:
@@ -96,17 +85,13 @@ jobs:
compile_arch: ${{ inputs.COMPILE_ARCH }}
fd_version: ${{ inputs.FD_VERSION }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
BRANCH_REF: ${{ github.ref_name }}
PADDLEVERSION: ${{ inputs.PADDLEVERSION }}
PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }}
WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }}
run: |
set -x
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
CARD_ID=$(echo "${runner_name}" | cut -d'-' -f2)
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
CACHE_DIR=${CACHE_DIR:-${{ github.workspace }}}
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
@@ -118,15 +103,11 @@ jobs:
-v $(pwd):/workspace -w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/.ccache:/root/.ccache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
-e "COMPILE_ARCH=${compile_arch}" \
-e "FD_VERSION=${fd_version}" \
-e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \
-e "PADDLEVERSION=${PADDLEVERSION}" \
-e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
-e "BRANCH_REF=${BRANCH_REF}" \
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
if [[ -n "${FD_VERSION}" ]]; then
export FASTDEPLOY_VERSION=${FD_VERSION}
@@ -142,20 +123,14 @@ jobs:
echo "Date Only: $DATE_ONLY"
export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}"
fi
# 针对不同分支和tag使用不同的PaddlePaddle安装包
if [[ "${PADDLE_WHL_URL}" != "" ]];then
python -m pip install ${PADDLE_WHL_URL}
elif [[ "${PADDLEVERSION}" != "" ]];then
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
else
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
fi
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
pip config set install.trusted-host pip.baidu.com
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install wheel
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
# 编译RDMA
export ENABLE_FD_RDMA=1
bash build.sh 1 python false [${COMPILE_ARCH}]

View File

@@ -1,73 +0,0 @@
name: Docker Build
description: "FastDeploy CI Image Build"
on:
workflow_call:
inputs:
CI_DOCKER_IMAGE_NAME:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
DOCKER_IMAGE_NAME:
description: "Build Images"
required: false
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
outputs:
docker_name_precheck:
description: "Output path of the generated wheel"
value: ${{ jobs.docker_build.outputs.docker_name_precheck }}
jobs:
docker_build:
runs-on: [self-hosted, Docker-Build]
outputs:
docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
steps:
- name: Code Prepare
id: docker_build
shell: bash
env:
docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }}
docker_image: ${{ inputs.DOCKER_IMAGE_NAME }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
# Docker Build
cd tools/dockerfile/
set -e
cp ../../requirements.txt ./
cp ../../scripts/unittest_requirement.txt ./
docker build -t ${docker_image_name} -f Dockerfile.ci . \
--network host \
--no-cache
docker push ${docker_image_name}
echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT

View File

@@ -68,7 +68,7 @@ jobs:
branch_name=${{ github.ref_name }}
target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -O bos_tools.py -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls

View File

@@ -1,184 +0,0 @@
name: Run FastDeploy LogProb Tests
description: "Run FastDeploy LogProb Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
PADDLETEST_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
default: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
run_tests_logprob:
runs-on: [self-hosted, GPU-h20-1Cards]
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
run: |
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
rm -rf /workspace/*
'
wget -q ${paddletest_archive_url}
tar -xf PaddleTest.tar.gz
rm -rf PaddleTest.tar.gz
cd PaddleTest
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: logprob test
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
chmod +x ./llm-deploy-linux-amd64
./llm-deploy-linux-amd64 -python python3.10 \
-model_name ERNIE-4.5-0.3B-Paddle \
-model_path /MODELDATA \
--skip install
cd PaddleTest/framework/ServeTest
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
curl -s -o /dev/null -w "%{http_code}" -m 2 "http://0.0.0.0:${FD_API_PORT}/health"
curl -X POST "http://0.0.0.0:${FD_API_PORT}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{\"messages\": [{\"role\": \"user\", \"content\": \"1+1=?\"}], \"logprobs\": true}"
set +e
rm -rf ./baseline_output
cp -r baseline/ERNIE-4.5-0.3B-Paddle ./baseline_output
LOGPROB_EXIT_CODE=0
python3.10 lanucher.py --request_template TOKEN_LOGPROB --url http://localhost:${FD_API_PORT}/v1/chat/completions --case ./cases/demo.yaml --concurrency 1 --name demo --exe logprob || LOGPROB_EXIT_CODE=$?
echo "LOGPROB_EXIT_CODE=${LOGPROB_EXIT_CODE}" > /workspace/exit_code.env
curl -X POST http://localhost:${FLASK_PORT}/stop
sleep 10s
cat *result.log
exit 0
'
if [ $? -ne 0 ];then
exit 1
fi
if [ -f exit_code.env ]; then
cat exit_code.env >> $GITHUB_ENV
fi
- name: logprob test result
if: ${{ env.LOGPROB_EXIT_CODE != 0 }}
shell: bash
run: |
echo "logprob test failed with exit code ${{ env.LOGPROB_EXIT_CODE }}"
exit 8

View File

@@ -1,148 +0,0 @@
name: Pre-CE-Test
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
run_ce_cases:
runs-on: [self-hosted, PRE_CE_RUN_2Card]
timeout-minutes: 60
steps:
- name: Print current runner name
run: |
echo "Current runner name: ${{ runner.name }}"
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run CI unittest
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
-e "MODEL_PATH=/ModelData" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "fd_wheel_url=${fd_wheel_url}" \
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install ${fd_wheel_url}
bash scripts/run_pre_ce.sh
'

View File

@@ -1,170 +0,0 @@
name: Stable Test
description: "Run Stable Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
stable_tests:
runs-on: [self-hosted, GPU-h1z1-2Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Stable Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42038 + DEVICE_PORT * 100))
FD_INFERENCE_MSG_QUEUE_ID=$(( 42048 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
TEST_EXIT_CODE=0
pushd tests/ce/stable_cases
bash launch_model.sh /MODELDATA
bash run.sh || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -1,319 +0,0 @@
name: Coverage Check
description: "Run FastDeploy Unit Tests and Coverage"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
secrets:
github-token:
required: true
jobs:
check_cov_skip:
uses: ./.github/workflows/check-bypass.yml
secrets:
github-token: ${{ secrets.github-token }}
with:
workflow-name: coverage
run_tests_with_coverage:
runs-on: [self-hosted, GPU-h1z1-2Cards]
timeout-minutes: 60
needs: check_cov_skip
if: needs.check_cov_skip.outputs.can-skip != 'true'
outputs:
diff_cov_file_url: ${{ steps.cov_upload.outputs.diff_cov_file_url }}
unittest_failed_url: ${{ steps.cov_upload.outputs.unittest_failed_url }}
diff_cov_result_json_url: ${{ steps.cov_upload.outputs.diff_cov_result_json_url }}
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Unit Tests and Coverage
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
BASE_REF: ${{ github.event.pull_request.base.ref }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
IS_PR: ${{ github.event_name == 'pull_request' }}
run: |
if [[ "$IS_PR" == "true" ]]; then
echo "Running on PR"
else
echo "Not a PR"
fi
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --net=host \
--name ${runner_name} \
--cap-add=SYS_PTRACE --shm-size=64G \
-v $(pwd):/workspace -w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
-e "MODEL_PATH=/ModelData" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e TZ="Asia/Shanghai" \
-e "fd_wheel_url=${fd_wheel_url}" \
-e "BASE_REF=${BASE_REF}" \
-e "IS_PR=${IS_PR}" \
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install -r scripts/unittest_requirement.txt
python -m pip install ${fd_wheel_url}
rm -rf fastdeploy
# coverage subprocess use
python -m pip install ${fd_wheel_url} --no-deps --target=/workspace/FastDeploy
export PYTHONPATH=/workspace/FastDeploy/
if [ -d "tests/plugins" ]; then
cd tests/plugins
python setup.py install
cd ../..
else
echo "Warning: tests/plugins directory not found, skipping setup.py install"
fi
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
TEST_EXIT_CODE=0
bash scripts/coverage_run.sh || TEST_EXIT_CODE=8
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env
coverage combine coveragedata/ || echo "No data to combine"
coverage report
coverage xml -o python_coverage_all.xml
COVERAGE_EXIT_CODE=0
if [[ "$IS_PR" == "true" ]]; then
diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9
python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml
else
echo "Not a PR, skipping diff-cover"
fi
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
'
if [ -f FastDeploy/exit_code.env ]; then
cat FastDeploy/exit_code.env >> $GITHUB_ENV
fi
- name: Upload unit resule and diff coverage to bos
id: cov_upload
shell: bash
run: |
cd FastDeploy
commit_id=${{ github.event.pull_request.head.sha }}
pr_num=${{ github.event.pull_request.number }}
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py -O bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
diff_cov_file="diff_coverage.xml"
if [ -f ${diff_cov_file} ];then
python ${push_file} ${diff_cov_file} ${target_path}/CoverageData
target_path_stripped="${target_path#paddle-github-action/}"
DIFF_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_file}
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_OUTPUT
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_ENV
fi
diff_cov_result_json="diff_coverage.json"
if [ -f ${diff_cov_result_json} ];then
python ${push_file} ${diff_cov_result_json} ${target_path}/CoverageData
target_path_stripped="${target_path#paddle-github-action/}"
DIFF_COV_JSON_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_result_json}
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV
fi
unittest_result="failed_tests.log"
if [ -s ${unittest_result} ];then
python ${push_file} ${unittest_result} ${target_path}/UnitTestResult
target_path_stripped="${target_path#paddle-github-action/}"
UNIT_TEST_RESULT_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/UnitTestResult/${unittest_result}
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
fi
- name: Check Unit Test Success
shell: bash
run: |
cd FastDeploy
if [ "$TEST_EXIT_CODE" -eq 8 ]; then
filename=$(basename "$unittest_failed_url")
if [ -z "${unittest_failed_url}" ]; then
echo "No diff unit failed file URL provided."
else
rm -rf "${filename}"
wget -O ${filename} ${unittest_failed_url} || echo "Download unittest file failed, but continuing..."
fi
echo "Unit tests failed (exit code 8)"
if [ -f "${filename}" ];then
echo "Failed test cases:"
cat "${filename}"
fi
exit "$TEST_EXIT_CODE"
fi
echo "All tests passed"
- name: Verify Code Coverage Threshold (80%)
if: ${{ github.event_name == 'pull_request' }}
shell: bash
run: |
cd FastDeploy
if [ "$COVERAGE_EXIT_CODE" -eq 9 ]; then
echo "Coverage generation failed (exit code 9)"
filename=$(basename "$diff_cov_result_json_url")
if [ -z "${diff_cov_result_json_url}" ]; then
echo "No diff cov result file URL provided."
else
rm -rf "${filename}"
wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..."
fi
if [ -f "${filename}" ];then
echo "Failed test cases:"
if command -v jq >/dev/null 2>&1; then
jq . "${filename}"
else
cat "${filename}"
fi
fi
exit "$COVERAGE_EXIT_CODE"
fi
echo "coverage passed"
exit 0
diff_coverage_report:
needs: run_tests_with_coverage
if: always()
runs-on: ubuntu-latest
env:
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
steps:
- name: coverage diff file download
shell: bash
env:
diff_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url }}
run: |
wget ${fd_archive_url}
tar -xf FastDeploy.tar.gz
cd FastDeploy
if [ -z "${diff_cov_file_url}" ]; then
echo "No diff coverage file URL provided."
exit 0
fi
wget "${diff_cov_file_url}" -O ./diff_coverage.xml || echo "Download cov file failed, but continuing..."
- name: Upload diff coverage report
if: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url != null && needs.run_tests_with_coverage.outputs.diff_cov_file_url != '' }}
uses: codecov/codecov-action@v5
with:
files: ./FastDeploy/diff_coverage.xml
name: python diff coverage
verbose: true
disable_search: true
commit_parent: false
flags: diff

View File

@@ -1,42 +0,0 @@
name: Approval
on:
pull_request:
branches:
- develop
- 'release/*'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs:
Approval:
name: Approval
if: ${{ github.repository_owner == 'PaddlePaddle' }}
runs-on: ubuntu-latest
env:
PR_ID: ${{ github.event.pull_request.number }}
BRANCH: ${{ github.event.pull_request.base.ref }}
steps:
- name: Checkout base repo
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.base.ref }}
fetch-depth: 1000
- name: Merge PR to test branch
run: |
git fetch origin pull/${PR_ID}/merge
git checkout -b test FETCH_HEAD
git log -n 3 --oneline
git remote add upstream https://github.com/PaddlePaddle/FastDeploy.git
git fetch upstream $BRANCH
- name: Setup python3.10
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Run approval check script
run: |
bash scripts/check_approval.sh

View File

@@ -1,248 +0,0 @@
name: CE Compile Job
on:
workflow_dispatch:
push:
branches:
- develop
- 'release/*'
permissions: read-all
concurrency:
group: ${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
ce_job_pre_check:
runs-on: ubuntu-latest
env:
COMPILE_BRANCH: ${{ vars.COMPILE_BRANCH }}
CE_COMPILE_SELECTION: ${{ vars.CE_COMPILE_SELECTION }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
outputs:
branch_match: ${{ steps.set_output.outputs.branch_match }}
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
sm8689_match: ${{ steps.set_output.outputs.sm8689_match }}
sm8090_match: ${{ steps.set_output.outputs.sm8090_match }}
steps:
- name: Set Version
id: set_output
env:
COMPILE_BRANCH: ${{ env.COMPILE_BRANCH }}
CE_COMPILE_SELECTION: ${{ env.CE_COMPILE_SELECTION }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ env.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
GITHUB_REF_NAME: ${{ github.ref_name }}
run: |
# 选择要触发编译任务的分支 done
# 选择指定分支要编译的任务 8090或者8689
# 指定分支编译要使用的Paddle的安装包,默认使用nightly最新的
IFS=',' read -ra BRANCHES <<< "$COMPILE_BRANCH"
MATCH=false
for b in "${BRANCHES[@]}"; do
if [[ "$b" == "${GITHUB_REF_NAME}" ]]; then
MATCH=true
break
fi
done
echo "branch_match=$MATCH" >> $GITHUB_OUTPUT
# 通过变量CE_COMPILE_SELECTION中的映射关系,决定分支是编译sm8090还是sm8689
for pair in $(echo "$CE_COMPILE_SELECTION" | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
compile_task_list=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "$GITHUB_REF_NAME" ]]; then
# 判断里面是否包含 sm8090 或 sm8689
if [[ "$compile_task_list" == *"sm8090"* ]]; then
echo "sm8090_match=true" >> $GITHUB_OUTPUT
fi
if [[ "$compile_task_list" == *"sm8689"* ]]; then
echo "sm8689_match=true" >> $GITHUB_OUTPUT
fi
break
fi
done
# 通过变量COMPILE_USE_PADDLE_WHL_URL_MAPPINGS中的映射关系,决定是否是安装指定版本的Paddle还是直接安装URL
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
FOUND_PADDLE_URL="$paddle_whl_url"
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
break
fi
done
print_ce_job_pre_check_outputs:
runs-on: ubuntu-latest
needs: ce_job_pre_check
steps:
- name: Print outputs as JSON
run: |
echo '${{ toJSON(needs.ce_job_pre_check.outputs) }}'
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
needs: ce_job_pre_check
if: ${{ needs.ce_job_pre_check.outputs.branch_match == 'true' }}
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request'
&& github.event.pull_request.base.ref
|| github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
build_sm8090:
name: BUILD_SM8090
needs: [clone, ce_job_pre_check]
if: ${{ needs.ce_job_pre_check.outputs.sm8090_match == 'true' }}
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "80,90"
WITH_NIGHTLY_BUILD: OFF
FD_VERSION: 0.0.0
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
build_sm8689:
name: BUILD_SM8689
needs: [clone, ce_job_pre_check]
if: ${{ needs.ce_job_pre_check.outputs.sm8689_match == 'true' }}
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "86,89"
WITH_NIGHTLY_BUILD: OFF
FD_VERSION: 0.0.0
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
ce_upload_sm8090:
environment: CodeSync
name: CE_UPLOAD
needs: build_sm8090
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
COMPILE_ARCH: "80,90"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
run: |
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
echo "commit wheel url is ${WHEEL_PATH}"
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
ce_upload_sm8689:
environment: CodeSync
name: CE_UPLOAD
needs: build_sm8689
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
COMPILE_ARCH: "86,89"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
run: |
echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
echo "commit wheel url is ${WHEEL_PATH}"
echo "latest wheel url is ${WHEEL_PATH_LATEST}"

View File

@@ -1,51 +0,0 @@
on:
workflow_call:
inputs:
workflow-name:
required: true
type: string
secrets:
github-token:
required: true
outputs:
can-skip:
description: "Whether the workflow can be skipped."
value: ${{ jobs.check-bypass.outputs.can-skip }}
jobs:
check-bypass:
name: Check bypass
runs-on: ubuntu-latest
permissions:
contents: read
env:
CI_TEAM_MEMBERS: '["yuanlehome","YuanRisheng","Jiang-Jia-Jun","DDDivano","XieYunshen"]'
outputs:
can-skip: ${{ steps.check-bypass.outputs.can-skip }}
steps:
- name: Cleanup
run: |
rm -rf * .[^.]*
- id: check-bypass
name: Check Bypass
uses: PFCCLab/ci-bypass@v1
with:
github-token: ${{ secrets.github-token }}
non-pull-request-event-strategy: 'never-skipped'
type: 'composite'
composite-rule: |
{
"any": [
{
"type": "labeled",
"label": ["skip-ci: ${{ inputs.workflow-name }}", "skip-ci: all"],
"username": ${{ env.CI_TEAM_MEMBERS }}
},
{
"type": "commented",
"comment-pattern": [".*/skip-ci ${{ inputs.workflow-name }}.*", ".*/skip-ci all.*"],
"username": ${{ env.CI_TEAM_MEMBERS }}
}
]
}

109
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,109 @@
name: CI
on:
pull_request:
branches:
- develop
- 'release/*'
workflow_dispatch:
concurrency:
group: ${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
build:
runs-on: [self-hosted, GPU-L20-4Card]
steps:
- name: Print current runner name
run: |
echo "Current runner name: ${{ runner.name }}"
# Because the system version is lower than 2.23, the checkout cannot be used.
# - name: Checkout code
# uses: actions/checkout@v4
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}
fi
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
git merge pr/${{ github.event.pull_request.number }}
git log -n 3 --oneline
else
git checkout ${{ github.sha }}
git log -n 3 --oneline
fi
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
if [ "${last_char}" = "1" ]; then
gpu_id=2
DEVICES="2,3"
else
gpu_id=0
DEVICES="0,1"
fi
FLASK_PORT=$((41068 + gpu_id * 100))
FD_API_PORT=$((41088 + gpu_id * 100))
FD_ENGINE_QUEUE_PORT=$((41058 + gpu_id * 100))
FD_METRICS_PORT=$((41078 + gpu_id * 100))
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
PARENT_DIR=$(dirname "$WORKSPACE")
echo "PARENT_DIR:$PARENT_DIR"
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-v "/ssd4/GithubActions/gitconfig:/etc/gitconfig:ro" \
-v "/ssd4/GithubActions/ModelData:/ModelData:ro" \
-v "/ssd4/GithubActions/CacheDir:/root/.cache" \
-v "/ssd4/GithubActions/ConfigDir:/root/.config" \
-e "MODEL_PATH=/ModelData" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c "
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci.sh
"

View File

@@ -1,98 +0,0 @@
name: CI_GCU
on:
pull_request:
branches:
- develop
- 'release/*'
workflow_dispatch:
concurrency:
group: ${{ github.event.pull_request.number }}-gcu-ci
cancel-in-progress: true
jobs:
CI_GCU:
runs-on:
group: GCU
steps:
- name: Print current runner name
run: |
echo "Current runner name: ${{ runner.name }}"
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace \
-v ${{ github.workspace }}/../../..:${{ github.workspace }}/../../.. \
-w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}
fi
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
source ${{ github.workspace }}/../../../proxy
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
git merge pr/${{ github.event.pull_request.number }}
git log -n 3 --oneline
else
git checkout ${{ github.sha }}
git log -n 3 --oneline
fi
echo "Copy models..."
sudo mkdir -p ci_models && sudo cp -r /work/deps/ERNIE-4.5-21B-A3B-Paddle ci_models
echo "Copy models done."
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
if [[ "$last_char" =~ [0-3] ]]; then
gcu_id="$last_char"
else
gcu_id="0"
fi
FD_API_PORT=$((9180 + gcu_id * 100))
FD_ENGINE_QUEUE_PORT=$((9150 + gcu_id * 100))
FD_METRICS_PORT=$((9170 + gcu_id * 100))
PARENT_DIR=$(dirname "$WORKSPACE")
echo "PARENT_DIR:$PARENT_DIR"
echo "Install drivers..."
cd /work/deps
sudo bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y
cd -
echo "Create docker..."
docker run --rm --network=host --ipc=host --privileged \
-v $(pwd):/workspace \
-v /home:/home \
-v /work:/work \
-w /workspace \
-e "MODEL_PATH=./ci_models" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
${docker_image} /bin/bash -c "
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci_gcu.sh
"

View File

@@ -11,8 +11,7 @@ concurrency:
jobs:
CI_ILUVATAR:
runs-on:
group: IXUCA
runs-on: [self-hosted, IXUCA]
steps:
- name: Print current runner name
run: |

View File

@@ -1,174 +0,0 @@
name: CI Images Build
on:
workflow_dispatch:
schedule:
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
permissions: read-all
concurrency:
group: ${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
if [[ "${{ github.ref_type }}" == "tag" ]]; then
commit_id=${{ github.sha }}
tag_name=${{ github.ref_name }}
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
else
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
ci_image_build:
name: CI Images Build
needs: clone
uses: ./.github/workflows/_ci_image_build.yml
with:
CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
build_sm8090:
name: BUILD_SM8090
needs: [clone, ci_image_build]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "90"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
unittest_coverage:
name: Run FastDeploy Unit Tests and Coverage
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build_sm8090,ci_image_build]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
publish_pre_check:
name: Publish Docker Images Pre Check
needs: [ci_image_build, unittest_coverage,logprob_test,pre_ce_test,base_test,accuracy_test,stable_test]
runs-on: [self-hosted, Docker-Build]
steps:
- name: Images Uploading
env:
images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
run: |
echo "images_name=${images_name}"
docker images ${ci_image_name}
docker tag ${images_name} ${ci_image_name}
docker push ${ci_image_name}

View File

@@ -24,7 +24,7 @@ jobs:
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
@@ -55,7 +55,7 @@ jobs:
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
@@ -77,7 +77,6 @@ jobs:
-e "MODEL_PATH=/ssd3/model" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \

View File

@@ -15,7 +15,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang mkdocs-static-i18n
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang
- name: Deploy to GitHub Pages
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -19,7 +19,7 @@ jobs:
needs: clone
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "90"
WITH_NIGHTLY_BUILD: "OFF"
@@ -33,65 +33,3 @@ jobs:
- name: Print wheel path
run: |
echo "The built wheel is located at: ${{ needs.build.outputs.wheel_path }}"
unittest_coverage:
name: Run FastDeploy Unit Tests and Coverage
needs: [clone,build]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

View File

@@ -1,331 +0,0 @@
name: Publish Job
on:
workflow_dispatch:
schedule:
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
push:
# branches:
# - develop
tags:
- '*'
permissions: read-all
concurrency:
group: ${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
publish_pre_check:
runs-on: ubuntu-latest
if: |
github.event.repository.fork == false &&
(
(github.event_name == 'schedule' && github.ref_name == 'develop') ||
(github.event_name == 'push' && github.ref_type == 'tag') ||
((github.event_name == 'workflow_dispatch') &&
(github.ref_name == 'develop' || github.ref_type == 'tag'))
)
env:
TAG_VERSION_MAPPINGS: ${{ vars.TAG_VERSION_MAPPINGS }}
FD_VERSION_DEV: ${{ vars.FD_VERSION_DEV }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
outputs:
compile_use_paddle_version: ${{ steps.set_output.outputs.compile_use_paddle_version }}
compile_continue: ${{ steps.set_output.outputs.compile_continue }}
fd_version: ${{ steps.set_output.outputs.fd_version }}
with_nightly_build: ${{ steps.set_output.outputs.with_nightly_build }}
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
steps:
- name: Get tag version
if: github.ref_type == 'tag'
run: |
TAG_NAME="${GITHUB_REF##*/}" # 提取 tag 名称,比如 v2.1.0
TAG_VERSION="${TAG_NAME#v}" # 去掉前缀 v
echo "FD_VERSION=$TAG_VERSION" >> $GITHUB_ENV
- name: Check FD version to Paddle version mapping
if: github.ref_type == 'tag'
env:
TARGET_FD: ${{ env.FD_VERSION }}
run: |
FOUND_PADDLE=""
# 遍历映射
for pair in $(echo $TAG_VERSION_MAPPINGS | tr ';' ' '); do
fd=$(echo "$pair" | cut -d',' -f1)
paddle=$(echo "$pair" | cut -d',' -f2)
if [[ "$fd" == "$TARGET_FD" ]]; then
FOUND_PADDLE="$paddle"
break
fi
done
if [[ -z "$FOUND_PADDLE" ]]; then
echo "No Paddle version found for FD $TARGET_FD"
else
echo "FD $TARGET_FD maps to Paddle $FOUND_PADDLE"
echo "PADDLE_VERSION=$FOUND_PADDLE" >> $GITHUB_ENV
fi
- name: Set Version
id: set_output
env:
PADDLE_VERSION: ${{ env.PADDLE_VERSION }}
FD_VERSION: ${{ env.FD_VERSION }}
run: |
if [[ "${{ github.ref_type }}" == "tag" ]]; then
if [[ -z "$PADDLE_VERSION" ]]; then
compile_continue=false
else
compile_use_paddle_version=$PADDLE_VERSION
compile_continue=true
fi
fd_version=$FD_VERSION
fi
if [[ "${{ github.ref_name }}" == "develop" ]];then
compile_continue=true
compile_use_paddle_version=""
fd_version=${FD_VERSION_DEV}
with_nightly_build=ON
fi
# Todo
# 通过变量COMPILE_USE_PADDLE_WHL_URL_MAPPINGS中的映射关系,决定是否是安装指定版本的Paddle还是直接安装URL
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
FOUND_PADDLE_URL="$paddle_whl_url"
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
compile_continue=true
break
fi
done
echo "compile_continue=${compile_continue}" >> $GITHUB_OUTPUT
echo "compile_use_paddle_version=${compile_use_paddle_version}" >> $GITHUB_OUTPUT
echo "fd_version=${fd_version}" >> $GITHUB_OUTPUT
echo "with_nightly_build=${with_nightly_build:-OFF}" >> $GITHUB_OUTPUT
print_publish_pre_check_outputs:
runs-on: ubuntu-latest
needs: publish_pre_check
steps:
- name: Print outputs as JSON
run: |
echo '${{ toJSON(needs.publish_pre_check.outputs) }}'
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
needs: publish_pre_check
if: ${{ needs.publish_pre_check.outputs.compile_continue == 'true' }}
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
if [[ "${{ github.ref_type }}" == "tag" ]]; then
commit_id=${{ github.sha }}
tag_name=${{ github.ref_name }}
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
else
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
build_sm8090:
name: BUILD_SM8090
needs: [clone, publish_pre_check]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "80,90"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
build_sm8689:
name: BUILD_SM8689
needs: [clone, publish_pre_check]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "86,89"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
paddle_pypi_upload_sm8090:
environment: PaddleSourceUpload
name: PADDLE_PYPI_UPLOAD_8090
needs: build_sm8090
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
COMPILE_ARCH: "80,90"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
if: github.ref_name == 'develop' || github.ref_type == 'tag'
run: |
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
if [[ "${{ github.ref_name }}" == "develop" ]];then
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
else
echo "Not develop or tag, do nothing"
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
paddle_pypi_upload_sm8689:
environment: PaddleSourceUpload
name: PADDLE_PYPI_UPLOAD_8689
needs: build_sm8689
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
COMPILE_ARCH: "86,89"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
if: github.ref_name == 'develop' || github.ref_type == 'tag'
run: |
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
if [[ "${{ github.ref_name }}" == "develop" ]];then
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
else
echo "Not develop or tag, do nothing"
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
unittest_coverage:
name: Run FastDeploy Unit Tests and Coverage
needs: [clone,build_sm8090]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build_sm8090]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build_sm8090]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

14
.gitignore vendored
View File

@@ -121,7 +121,7 @@ dmypy.json
FETCH_HEAD
#log
log/
log*/
checkpoints/
checkpoints_origin/
@@ -156,12 +156,6 @@ nohup.out
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cutlass
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
#marlin_kernel
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu
#machete_kernel
custom_ops/gpu_ops/machete/generated
# buff
custom_ops/tmp*
@@ -170,9 +164,3 @@ build
.ccls-cache
third_party
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h

View File

@@ -1,4 +1,3 @@
English | [简体中文](README_CN.md)
<p align="center">
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
</p>
@@ -23,12 +22,11 @@ English | [简体中文](README_CN.md)
</p>
--------------------------------------------------------------------------------
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
## News
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It now offers compatibility with models in the HuggingFace ecosystem, has further optimized performance, and newly adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!
**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
@@ -52,16 +50,14 @@ English | [简体中文](README_CN.md)
## Installation
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, **Hygon DCUs** and other hardware. For detailed installation instructions:
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. For detailed installation instructions:
- [NVIDIA GPU](./docs/get_started/installation/nvidia_gpu.md)
- [Kunlunxin XPU](./docs/get_started/installation/kunlunxin_xpu.md)
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates!
## Get Started
@@ -71,12 +67,19 @@ Learn how to use FastDeploy through our documentation:
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
- [Offline Inference Development](./docs/offline_inference.md)
- [Online Service Deployment](./docs/online_serving/README.md)
- [Best Practices](./docs/best_practices/README.md)
- [Full Supported Models List](./docs/supported_models.md)
## Supported Models
Learn how to download models, enable using the torch format, and more:
- [Full Supported Models List](./docs/supported_models.md)
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅(WINT4)| WIP |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|✅(WINT4)| WIP | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
## Advanced Usage

View File

@@ -1,89 +0,0 @@
[English](README.md) | 简体中文
<p align="center">
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
</p>
<p align="center">
<a href=""><img src="https://img.shields.io/badge/python-3.10-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux-pink.svg"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/graphs/contributors"><img src="https://img.shields.io/github/contributors/PaddlePaddle/FastDeploy?color=9ea"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
</p>
<p align="center">
<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/"><b> 安装指导 </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/quick_start"><b> 快速入门 </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/zh/supported_models/"><b> 支持模型列表 </b></a>
</p>
--------------------------------------------------------------------------------
# FastDeploy :基于飞桨的大语言模型与视觉语言模型推理部署工具包
## 最新活动
**[2025-09] 🔥 FastDeploy v2.2 全新发布**: HuggingFace生态模型兼容性能进一步优化更新增对[baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)支持!
**[2025-08] FastDeploy v2.1 发布**:全新的KV Cache调度策略更多模型支持PD分离和CUDA Graph昆仑、海光等更多硬件支持增强全方面优化服务和推理引擎的性能。
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
## 关于
**FastDeploy** 是基于飞桨PaddlePaddle的大语言模型LLM与视觉语言模型VLM推理部署工具包提供**开箱即用的生产级部署方案**,核心技术特性包括:
- 🚀 **负载均衡式PD分解**工业级解决方案支持上下文缓存与动态实例角色切换在保障SLO达标和吞吐量的同时优化资源利用率
- 🔄 **统一KV缓存传输**轻量级高性能传输库支持智能NVLink/RDMA选择
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
- 🧮 **全量化格式支持**W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
-**高级加速技术**推测解码、多令牌预测MTP及分块预填充
- 🖥️ **多硬件支持**NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
## 要求
- 操作系统: Linux
- Python: 3.10 ~ 3.12
## 安装
FastDeploy 支持在**英伟达NVIDIAGPU**、**昆仑芯KunlunxinXPU**、**天数IluvatarGPU**、**燧原EnflameGCU**、**海光HygonDCU** 以及其他硬件上进行推理部署。详细安装说明如下:
- [英伟达 GPU](./docs/zh/get_started/installation/nvidia_gpu.md)
- [昆仑芯 XPU](./docs/zh/get_started/installation/kunlunxin_xpu.md)
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md)
**注意:** 我们正在积极拓展硬件支持范围。目前包括昇腾AscendNPU 等其他硬件平台正在开发测试中。敬请关注更新!
## 入门指南
通过我们的文档了解如何使用 FastDeploy
- [10分钟快速部署](./docs/zh/get_started/quick_start.md)
- [ERNIE-4.5 部署](./docs/zh/get_started/ernie-4.5.md)
- [ERNIE-4.5-VL 部署](./docs/zh/get_started/ernie-4.5-vl.md)
- [离线推理](./docs/zh/offline_inference.md)
- [在线服务](./docs/zh/online_serving/README.md)
- [最佳实践](./docs/zh/best_practices/README.md)
## 支持模型列表
通过我们的文档了解如何下载模型如何支持torch格式等
- [模型支持列表](./docs/zh/supported_models.md)
## 进阶用法
- [量化](./docs/zh/quantization/README.md)
- [分离式部署](./docs/zh/features/disaggregated.md)
- [投机解码](./docs/zh/features/speculative_decoding.md)
- [前缀缓存](./docs/zh/features/prefix_caching.md)
- [分块预填充](./docs/zh/features/chunked_prefill.md)
## 致谢
FastDeploy 依据 [Apache-2.0 开源许可证](./LICENSE). 进行授权。在开发过程中,我们参考并借鉴了 [vLLM](https://github.com/vllm-project/vllm) 的部分代码,以保持接口兼容性,在此表示衷心感谢。

View File

@@ -361,7 +361,8 @@ async def benchmark(
if not test_output.success:
raise ValueError(
f"Initial test run failed - Please make sure that 1. benchmark arguments are correctly specified and 2. the http_proxy and https_proxy are turned off. Error: {test_output.error}"
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")

View File

@@ -1,6 +0,0 @@
num_gpu_blocks_override: 1024
max_model_len: 8192
max_num_seqs: 64
data_parallel_size: 8
tensor_parallel_size: 1
enable_expert_parallel: True

View File

@@ -1,8 +0,0 @@
top_p: 0.95
temperature: 0.6
metadata:
min_tokens: 1
max_tokens: 65535
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0

View File

@@ -1,10 +0,0 @@
reasoning-parser: ernie_x1
tool_call_parser: ernie_x1
tensor_parallel_size: 4
max_model_len: 65536
max_num_seqs: 128
enable_prefix_caching: True
enable_chunked_prefill: True
gpu_memory_utilization: 0.85
use_cudagraph: True
enable_custom_all_reduce: True

View File

@@ -34,6 +34,7 @@ EGG_DIR="fastdeploy.egg-info"
# custom_ops directory config
OPS_SRC_DIR="custom_ops"
OPS_TMP_DIR_BASE="tmp_base"
OPS_TMP_DIR="tmp"
# command line log config
@@ -70,20 +71,25 @@ function copy_ops(){
PY_VERSION="py${PY_MAIN_VERSION}.${PY_SUB_VERSION}"
SYSTEM_VERSION=`${python} -c "import platform; print(platform.system().lower())"`
PROCESSOR_VERSION=`${python} -c "import platform; print(platform.processor())"`
WHEEL_BASE_NAME="fastdeploy_base_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
WHEEL_NAME="fastdeploy_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
WHEEL_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
if [ "$is_rocm" = "True" ]; then
DEVICE_TYPE="rocm"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "ROCM ops have been copy to fastdeploy"
echo -e "BASE and ROCM ops have been copy to fastdeploy"
return
fi
mkdir -p ../fastdeploy/model_executor/ops/base
is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"`
if [ "$is_cuda" = "True" ]; then
DEVICE_TYPE="gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "CUDA ops have been copy to fastdeploy"
echo -e "BASE and CUDA ops have been copy to fastdeploy"
return
fi
@@ -106,8 +112,9 @@ function copy_ops(){
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
if [ "$if_corex" = "True" ]; then
DEVICE_TYPE="iluvatar-gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
echo -e "Iluvatar ops have been copy to fastdeploy"
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
return
fi
@@ -119,26 +126,20 @@ function copy_ops(){
return
fi
is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
if [ "$is_maca" = "True" ]; then
DEVICE_TYPE="metax_gpu"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "MACA ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../
cp -r ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu
echo -e "CPU ops have been copy to fastdeploy"
echo -e "BASE and CPU ops have been copy to fastdeploy"
return
}
function build_and_install_ops() {
cd $OPS_SRC_DIR
export no_proxy=bcebos.com,paddlepaddle.org.cn,${no_proxy}
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_base_ops..."
${python} setup_ops_base.py install --install-lib ${OPS_TMP_DIR_BASE}
find ${OPS_TMP_DIR_BASE} -type f -name "*.o" -exec rm -f {} \;
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..."
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
@@ -212,6 +213,7 @@ function cleanup() {
fi
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR_BASE
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR
}

View File

@@ -84,6 +84,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
seq_length,
bsz);
return {x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k};
@@ -96,7 +97,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -105,6 +106,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
@@ -113,6 +115,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,11 +19,10 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <typename T>
void RebuildPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cu_seqlens_q_data,
const int *cum_offsets_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
@@ -41,12 +40,11 @@ void RebuildPaddingCPUImpl(T *output_data,
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
continue;
}
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
const int ori_token_idx =
bi * max_input_length - cum_offsets_data[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
@@ -56,7 +54,7 @@ void RebuildPaddingCPUImpl(T *output_data,
template <typename T>
void RebuildAppendPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cu_seqlens_q_data,
const int *cum_offsets_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
@@ -71,32 +69,30 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
int bi = ori_token_id / max_input_length;
if (seq_len_this_time_data[bi] == 0 ||
(seq_lens_decoder_data[bi] == 0 &&
seq_lens_encoder_data[bi] == 0)) {
continue;
}
seq_lens_encoder_data[bi] == 0)) {
continue;
}
int seq_id = 0;
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id;
int bias_idx = i % dim_embed;
int src_offset = input_token_id * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
}
}
std::vector<paddle::Tensor> RebuildPaddingCPU(
const paddle::Tensor &tmp_out,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true);
auto seq_len_this_time_cpu =
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
auto seq_lens_decoder_cpu =
@@ -111,7 +107,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
int token_num = tmp_out_cpu.shape()[0];
int dim_embed = tmp_out_cpu.shape()[1];
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
int bsz = cum_offsets_cpu.shape()[0];
paddle::Tensor out;
if (output_padding_offset_cpu) {
@@ -132,7 +128,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
}
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
const int *cum_offsets_data = cum_offsets_cpu.data<int>();
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
@@ -145,7 +141,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32:
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -158,7 +154,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -171,7 +167,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -190,7 +186,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32:
RebuildPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -202,7 +198,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -211,10 +207,11 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
elem_nums);
break;
case paddle::DataType::BFLOAT16:
RebuildPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -233,7 +230,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &tmp_out_shape,
const std::vector<int64_t> &cu_seqlens_q_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &seq_len_this_time_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape,
@@ -242,14 +239,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
if (output_padding_offset_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cu_seqlens_q_shape[0] - 1;
int64_t bsz = cum_offsets_shape[0];
return {{bsz, dim_embed}};
}
}
std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &tmp_out_dtype,
const paddle::DataType &cu_seqlens_q_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &seq_len_this_time_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype,
@@ -259,7 +256,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
.Inputs({"tmp_out",
"cu_seqlens_q",
"cum_offsets",
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",

View File

@@ -14,7 +14,7 @@
#include "paddle/extension.h"
void set_value_by_flags_and_idx(const bool *stop_flags,
void set_value_by_flag_and_id(const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
@@ -50,7 +50,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
int length = pre_ids_all_shape[1];
int length_input_ids = input_ids.shape()[1];
set_value_by_flags_and_idx(stop_flags.data<bool>(),
set_value_by_flag_and_id(stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
input_ids.data<int64_t>(),
seq_lens_encoder.data<int>(),

View File

@@ -46,7 +46,7 @@ void update_inputs_kernel(bool *not_need_stop,
not_need_stop[0] = stop_sum < stop_nums[0];
}
void UpdateInputs(const paddle::Tensor &stop_flags,
void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
@@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu)
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"input_ids", "input_ids_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputs));
.SetKernelFn(PD_KERNEL(UpdateInputes));

View File

@@ -38,7 +38,7 @@ class type2value<phi::dtype::float16> {
template <paddle::DataType D>
void AppendAttentionKernel(
std::vector<paddle::Tensor> AppendAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
@@ -60,7 +60,6 @@ void AppendAttentionKernel(
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -73,11 +72,7 @@ void AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
@@ -123,6 +118,27 @@ void AppendAttentionKernel(
} else {
qkv_out = qkv;
}
paddle::Tensor fmha_out;
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
D,
qkv.place());
}
auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
const paddle::Tensor& lambda_batch_ids,
@@ -140,8 +156,8 @@ void AppendAttentionKernel(
key_cache,
value_cache,
attn_mask,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
@@ -207,10 +223,7 @@ void AppendAttentionKernel(
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
};
if (qkv_out_scales) {
@@ -257,6 +270,54 @@ void AppendAttentionKernel(
if (speculate_decoder) {
if (qkv_out_scales) {
SpeculateWriteCacheWithRoPEKernel<data_t, int>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
}
} else {
if (qkv_out_scales) {
DecoderWriteCacheWithRoPEKernel<data_t, int>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
@@ -278,12 +339,9 @@ void AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
@@ -305,64 +363,7 @@ void AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
} else {
if (qkv_out_scales) {
DecoderWriteCacheWithRoPEKernel<data_t, int>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
}
}
@@ -391,6 +392,8 @@ void AppendAttentionKernel(
cudaStreamWaitEvent(main_stream, decoder_event);
}
}
return {fmha_out, qkv_out};
}
std::vector<paddle::Tensor> AppendAttention(
@@ -426,11 +429,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -465,60 +464,8 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];
// template dtype generation
phi::DataType dtype_id;
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
dtype_id = phi::DataType::BFLOAT16;
break;
} else if (compute_dtype == "fp16") {
dtype_id = phi::DataType::FLOAT16;
break;
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}
// fmha_out generation, rewrite from AppendAttentionKernel
paddle::Tensor fmha_out;
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
} else{
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
dtype_id,
qkv.place());
}
if (mask_offset) {
meta_data.mask_offset = mask_offset.get().data<int>();
}
auto dispatch_by_template = [&](auto temp_args) -> void {
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
@@ -540,7 +487,6 @@ std::vector<paddle::Tensor> AppendAttention(
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
@@ -553,11 +499,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
@@ -572,198 +514,35 @@ std::vector<paddle::Tensor> AppendAttention(
speculate_max_draft_token_num,
causal,
speculate_decoder);
};
};
phi::dtype::float16 fp16_dtype;
phi::dtype::bfloat16 bp16_dtype;
switch (dtype_id){
case phi::DataType::FLOAT16: {
dispatch_by_template(fp16_dtype);
return {fmha_out};
}
case phi::DataType::BFLOAT16: {
dispatch_by_template(bp16_dtype);
return {fmha_out};
}
default:
PD_THROW(
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype);
case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype);
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
return dispatch_by_template(bp16_dtype);
} else if (compute_dtype == "fp16") {
return dispatch_by_template(fp16_dtype);
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
break;
}
}
return {paddle::Tensor{}};
}
void AppendAttentionWithOutput(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
AppendAttnMetaData meta_data;
const auto& qkv_dims = qkv.dims();
const auto& key_cache_dims = key_cache.dims();
meta_data.token_nums = qkv_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
// TODO: trick method support c4, add attr head_dims in the future
if (cache_quant_type_str == "cache_int4_zp") {
meta_data.head_dims *= 2;
}
const int total_num_head =
qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];
if (mask_offset) {
meta_data.mask_offset = mask_offset.get().data<int>();
}
auto dispatch_by_template = [&](auto temp_args) -> void {
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
qkv_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
};
phi::dtype::float16 fp16_dtype;
phi::dtype::bfloat16 bp16_dtype;
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: {
dispatch_by_template(fp16_dtype);
break;
}
case paddle::DataType::BFLOAT16: {
dispatch_by_template(bp16_dtype);
break;
}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
dispatch_by_template(bp16_dtype);
break;
} else if (compute_dtype == "fp16") {
dispatch_by_template(fp16_dtype);
break;
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}
}
std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
@@ -797,11 +576,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -825,7 +600,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
}
const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
const int num_heads = total_num_head - 2 * kv_num_heads;
return {{token_num, num_heads * head_dim}};
return {{token_num, num_heads * head_dim}, qkv_shape};
}
std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -861,11 +636,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -884,148 +655,32 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
if (compute_dtype == "bf16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
return {paddle::DataType::INT8};
return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
return {paddle::DataType::FLOAT8_E4M3FN};
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
return {paddle::DataType::BFLOAT16};
return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
}
} else if (compute_dtype == "fp16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
return {paddle::DataType::INT8};
return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
return {paddle::DataType::FLOAT8_E4M3FN};
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
return {paddle::DataType::FLOAT16};
return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
}
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
}
}
std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& set_max_lengths_shape,
const std::vector<int64_t>& max_len_kv_shape,
const std::vector<int64_t>& fmha_out_shape,
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
return {fmha_out_shape};
}
std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::DataType& qkv_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& set_max_lengths_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::DataType& fmha_out_dtype,
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
return {fmha_out_dtype};
}
PD_BUILD_STATIC_OP(append_attention)
.Inputs({"qkv",
"key_cache",
@@ -1059,15 +714,11 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
paddle::Optional("kv_signal_data")})
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
.Attrs({"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
@@ -1081,71 +732,7 @@ PD_BUILD_STATIC_OP(append_attention)
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
})
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
PD_BUILD_STATIC_OP(append_attention_with_output)
.Inputs({"qkv",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"set_max_lengths",
"max_len_kv",
"fmha_out",
paddle::Optional("rotary_embs"),
paddle::Optional("attn_mask"),
paddle::Optional("qkv_bias"),
paddle::Optional("qkv_out_scales"),
paddle::Optional("cache_k_quant_scales"),
paddle::Optional("cache_v_quant_scales"),
paddle::Optional("cache_k_dequant_scales"),
paddle::Optional("cache_v_dequant_scales"),
paddle::Optional("cache_k_zp"),
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"fmha_out", "fmha_out_out"},
{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
})
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));

View File

@@ -43,7 +43,6 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -52,7 +51,6 @@ __global__ void multi_query_append_attention_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -75,11 +73,6 @@ __global__ void multi_query_append_attention_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -148,7 +141,6 @@ __global__ void multi_query_append_attention_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -187,7 +179,7 @@ __global__ void multi_query_append_attention_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
@@ -253,16 +245,12 @@ __global__ void multi_query_append_attention_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -418,8 +406,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -428,14 +414,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -452,11 +436,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -523,7 +502,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -561,9 +540,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -631,15 +611,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -905,7 +882,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -914,7 +890,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -964,7 +939,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -973,7 +947,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1088,18 +1061,12 @@ void MultiQueryAppendAttention(
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_dec_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_seq_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
false,
@@ -1137,9 +1104,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1148,13 +1112,11 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1199,8 +1161,8 @@ void MultiQueryAppendAttention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1210,9 +1172,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1221,13 +1180,11 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
@@ -1251,8 +1208,8 @@ void MultiQueryAppendAttention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1269,14 +1226,14 @@ void MultiQueryAppendAttention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1287,8 +1244,8 @@ void MultiQueryAppendAttention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,

View File

@@ -48,7 +48,6 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -57,7 +56,6 @@ __global__ void multi_query_append_attention_c4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -86,11 +84,6 @@ __global__ void multi_query_append_attention_c4_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -179,7 +172,6 @@ __global__ void multi_query_append_attention_c4_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -256,7 +248,7 @@ __global__ void multi_query_append_attention_c4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -341,15 +333,12 @@ __global__ void multi_query_append_attention_c4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -516,8 +505,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -526,14 +513,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -556,11 +541,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -647,7 +627,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -723,9 +703,10 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -807,15 +788,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -1110,7 +1088,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1119,7 +1096,6 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1175,7 +1151,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1184,7 +1159,6 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1311,18 +1285,10 @@ void MultiQueryAppendC4Attention(
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1368,9 +1334,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1379,13 +1342,11 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1431,15 +1392,15 @@ void MultiQueryAppendC4Attention(
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_k_zp.get().data<T>()))
: nullptr,
const_cast<T *>(cache_k_zp.get().data<T>()))
: nullptr,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_v_zp.get().data<T>()))
: nullptr,
const_cast<T *>(cache_v_zp.get().data<T>()))
: nullptr,
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1449,9 +1410,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1460,13 +1418,11 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1489,8 +1445,8 @@ void MultiQueryAppendC4Attention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1507,14 +1463,14 @@ void MultiQueryAppendC4Attention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1525,8 +1481,8 @@ void MultiQueryAppendC4Attention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,

View File

@@ -32,15 +32,14 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__global__ void multi_query_append_attention_c8_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads]
const T *__restrict__ cache_v_scale, // [num_kv_heads]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -49,7 +48,6 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -58,7 +56,6 @@ __global__ void multi_query_append_attention_c8_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -88,40 +85,33 @@ __global__ void multi_query_append_attention_c8_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
const uint32_t q_end =
@@ -189,7 +179,6 @@ __global__ void multi_query_append_attention_c8_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -210,13 +199,6 @@ __global__ void multi_query_append_attention_c8_kernel(
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + num_frags_z * 16;
}
const uint32_t num_iterations = div_up(
@@ -234,7 +216,7 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -298,22 +280,10 @@ __global__ void multi_query_append_attention_c8_kernel(
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -330,15 +300,12 @@ __global__ void multi_query_append_attention_c8_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -346,7 +313,6 @@ __global__ void multi_query_append_attention_c8_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
const int ori_kv_idx_base = kv_idx_base;
kv_idx_base += num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -365,18 +331,6 @@ __global__ void multi_query_append_attention_c8_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
@@ -387,9 +341,7 @@ __global__ void multi_query_append_attention_c8_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
is_scale_channel_wise, IsFP8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -506,15 +458,14 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise=false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -523,8 +474,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -533,14 +482,12 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -563,39 +510,32 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
@@ -661,7 +601,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -686,13 +626,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
}
const uint32_t num_iterations = div_up(
CAUSAL
@@ -709,7 +642,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -775,23 +708,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
commit_group();
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -807,16 +728,12 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -824,7 +741,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
const uint32_t ori_kv_idx_base = kv_idx_base;
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -843,18 +759,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
@@ -865,9 +769,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
is_scale_channel_wise, IsFP8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -981,8 +883,7 @@ template <typename T,
uint32_t NUM_WARP_Q,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
void MultiQueryAppendC8Attention(
const AppendAttnMetaData &meta_data,
const paddle::Tensor &qkv,
@@ -1040,8 +941,7 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
num_frags_z * 16 * sizeof(T) * 2;
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
@@ -1058,9 +958,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -1078,9 +976,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1114,9 +1010,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -1134,9 +1028,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1162,7 +1054,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1171,7 +1062,6 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1221,7 +1111,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1230,7 +1119,6 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1316,8 +1204,7 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
constexpr uint32_t smem_size =
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1334,9 +1221,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1354,9 +1239,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1371,17 +1254,10 @@ void MultiQueryAppendC8Attention(
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1398,9 +1274,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1418,9 +1292,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1446,9 +1318,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1457,13 +1326,11 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1510,8 +1377,8 @@ void MultiQueryAppendC8Attention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1521,9 +1388,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1532,13 +1396,11 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1556,8 +1418,8 @@ void MultiQueryAppendC8Attention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1574,14 +1436,14 @@ void MultiQueryAppendC8Attention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1592,8 +1454,8 @@ void MultiQueryAppendC8Attention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1655,7 +1517,6 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out) {
const auto token_num = meta_data.token_nums;
@@ -1664,7 +1525,6 @@ void CascadeAppendAttentionC8Kernel(
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim = meta_data.head_dims;
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";
DISPATCH_CAUSAL(
causal,
@@ -1683,46 +1543,43 @@ void CascadeAppendAttentionC8Kernel(
BLOCK_SIZE,
{DISPATCH_BLOCKSHAPE_Q(
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
BLOCK_SIZE,
CAUSAL,
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL,
IsFP8,
IsDynamicC8>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale.get(),
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
out);
})})})})})})})
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
BLOCK_SIZE,
CAUSAL,
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL, IsFP8>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale.get(),
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
out);
})})})})})})
}

View File

@@ -384,113 +384,6 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
}
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
}
}
}
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
@@ -923,8 +816,7 @@ template <uint32_t num_frags_x,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
@@ -968,27 +860,20 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
#pragma unroll
@@ -1020,15 +905,12 @@ template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
bool IS_SYSTEM = false>
__device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_idx_base,
__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
const uint32_t kv_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
float (*s_frag)[num_frags_z][8]) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -1042,21 +924,10 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
group_size,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
}
const bool out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
@@ -1064,7 +935,6 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
@@ -1208,9 +1078,7 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool is_scale_channel_wise = false, bool IsFP8=false>
__device__ __forceinline__ void compute_sfm_v_c8(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1252,28 +1120,16 @@ __device__ __forceinline__ void compute_sfm_v_c8(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
@@ -1300,9 +1156,7 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool is_scale_channel_wise = false, bool IsFP8=false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1346,28 +1200,16 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16

View File

@@ -103,7 +103,6 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -265,10 +264,9 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
} else if (cache_quant_type_str == "cache_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
@@ -301,7 +299,6 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {

View File

@@ -18,189 +18,6 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
// Note(ZKK)
// This function is very easy!
// just make HeadDim data to be new HeadDim data!
template <typename T, int VecSize=8, int HEAD_DIM=128, int NUM_THREADS=32>
__device__ __forceinline__ void apply_rope(
const T* input,
const float* cos_emb,
const float* sin_emb,
T* output,
const int thread_id) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
#pragma unroll
for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM; head_bias += NUM_THREADS * VecSize) {
Load<T, VecSize>(&input[head_bias], &src_vec);
const uint32_t emb_idx = head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
out_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
out_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(out_vec, &output[head_bias]);
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadKVT cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec, k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int half_head_size = head_size / 2;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
const uint32_t ori_idx =
start_token_idx * hidden_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (hi < num_heads + kv_num_heads) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
} else {
out_vec[2 * i] = src_vec[2 * i];
out_vec[2 * i + 1] = src_vec[2 * i + 1];
}
}
if (hi < (num_heads + kv_num_heads)) { // q k
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) { // q
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else { // k
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(out_vec, &qkv_out[ori_idx]);
} else {
// quant + write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
if (hi < num_heads + kv_num_heads) {
Store<T, VecSize>(out_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(out_vec, &value_cache[tgt_idx]);
}
}
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -211,7 +28,7 @@ __global__ void append_decode_cache_T_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -317,7 +134,7 @@ __global__ void append_decode_cache_T_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -382,9 +199,8 @@ __global__ void append_decode_cache_T_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -438,6 +254,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -449,8 +266,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
@@ -497,9 +313,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -551,6 +366,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -566,8 +382,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
using LoadT = AlignedVector<int, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
@@ -624,9 +439,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -674,8 +488,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
}
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
__global__ void append_decode_cache_int8_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
@@ -684,24 +498,21 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
T* __restrict__ cache_k_scale,
T* __restrict__ cache_v_scale,
const float* q_norm_weight,
const float* k_norm_weight,
const T* __restrict__ cache_k_scale,
const T* __restrict__ cache_v_scale,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d,
const float rms_norm_eps) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -722,18 +533,6 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
int cache_offset;
if (head_idx < num_heads) {
cache_offset = 0;
} else if (head_idx < num_heads + 2 * kv_num_heads) {
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
}
T *cache_k_scale_now = cache_k_scale + cache_offset;
T *cache_v_scale_now = cache_v_scale + cache_offset;
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
@@ -753,11 +552,11 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -766,260 +565,13 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[2 * i] =
static_cast<T>(tmp1);
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
out_vec[2 * i + 1] =
static_cast<T>(tmp2);
}
// qk norm
if (q_norm_weight) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
LoadOutScaleT q_norm_vec;
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
}
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
}
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
if (block_offset == 0) {
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
if (head_idx < num_heads + kv_num_heads) {
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim;
block_i < block_size;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&key_cache[tgt_idx + block_i * HeadDim]);
}
} else {
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
const int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
}
}
__syncwarp();
}
constexpr int K_VEC_SIZE = 4;
constexpr int HALF_K_VEC_SIZE = 2;
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
using LoadEmbT = AlignedVector<float, 1>;
LoadKVResT cache_vec;
LoadT src_vec1, src_vec2;
LoadBiasT out_vec1, out_vec2;
LoadEmbT cos_emb_vec1, cos_emb_vec2;
LoadEmbT sin_emb_vec1, sin_emb_vec2;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
T scale = T(1.0f);
const int k_head_idx = head_idx - num_heads;
const int v_head_idx = head_idx - num_heads - kv_num_heads;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
}
float input_left = static_cast<float>(src_vec1[0]);
float input_right = static_cast<float>(src_vec1[1]);
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec1[0] =
static_cast<T>(tmp1);
out_vec1[1] =
static_cast<T>(tmp2);
} else {
out_vec1[0] = src_vec1[0];
out_vec1[1] = src_vec1[1];
}
// rope
input_left = static_cast<float>(src_vec2[0]);
input_right = static_cast<float>(src_vec2[1]);
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec2[0] = static_cast<T>(tmp1);
out_vec2[1] = static_cast<T>(tmp2);
} else {
out_vec2[0] = src_vec2[0];
out_vec2[1] = src_vec2[1];
}
if (k_norm_weight) {
if (head_idx < num_heads + kv_num_heads) {
LoadOutScaleT k_norm_vec1, k_norm_vec2;
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
// qk norm
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
}
}
}
// reduce max, 1 head per warp
T local_max = -INFINITY;
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(out_vec1[i]));
local_max = __hmax(local_max, __habs(out_vec2[i]));
}
#pragma unroll
for (int m_offset = 16; m_offset > 1; m_offset /= 2) {
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}
scale = __hdiv(448, local_max);
if (lane_id == 0) {
if (head_idx < num_heads + kv_num_heads) {
cache_k_scale_now[0] = __hdiv(1, scale);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
}
}
#pragma unroll
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
}
if (head_idx < num_heads + kv_num_heads) {
const int start_block_16 =
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
const uint32_t tgt_cache_idx =
block_idx * kv_num_heads * block_size * HeadDim +
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
} else {
const uint32_t base_tgt_cache_idx =
block_idx * kv_num_heads * HeadDim * block_size +
kv_head_idx * HeadDim * block_size +
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
block_offset % 8 / 2 * 4 // per 4
+ block_offset % 16 / 8 * 2 // per 2
+ block_offset % 2; // per 1
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
value_cache[tgt_cache_idx1] = cache_vec[0];
value_cache[tgt_cache_idx2] = cache_vec[1];
value_cache[tgt_cache_idx3] = cache_vec[2];
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
__global__ void append_decode_cache_int8_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const T* __restrict__ cache_k_scale,
const T* __restrict__ cache_v_scale,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane_id = tid % 32;
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
const int* block_table_now = nullptr;
block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
if (head_idx < num_heads) {
// q
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
uint32_t emb_offset = write_seq_id * half_head_size;
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
apply_rope<T, VecSize, HeadDim, 32>(
qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
qkv_out_now,
lane_id);
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
@@ -1081,11 +633,10 @@ __global__ void append_decode_cache_int8_rope_kernel(
const T *cache_v_scale_cur = cache_v_scale + v_head_idx * HeadDim + head_bias;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
if constexpr (!is_scale_channel_wise) {
scale = __ldg(&cache_k_scale[kv_head_idx]);
}
@@ -1194,6 +745,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1211,8 +763,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1262,10 +813,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -1358,11 +908,10 @@ __global__ void append_decode_cache_int8_rope_kernel(
const T *cache_v_scale_cur = cache_v_scales + v_head_idx * HeadDim + head_bias;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
if constexpr (!is_scale_channel_wise) {
scale = __ldg(&cache_k_scales[kv_head_idx]);
}
@@ -1498,6 +1047,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1511,8 +1061,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1560,9 +1109,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -1643,11 +1191,10 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T scale;
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -1799,7 +1346,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1817,8 +1364,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1878,10 +1424,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -1989,11 +1533,10 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T scale;
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -2196,7 +1739,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2212,8 +1755,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -2237,18 +1779,43 @@ __global__ void append_decode_cache_int4_rope_kernel(
if (head_idx < num_heads) {
// q
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
uint32_t emb_offset = write_seq_id * half_head_size;
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
apply_rope<T, VecSize, HeadDim, 32>(
qkv_now,
cos_emb + emb_offset,
sin_emb + emb_offset,
qkv_out_now,
lane_id);
LoadT src_vec;
LoadBiasT out_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
#pragma unroll
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
out_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
out_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
}
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
@@ -2307,11 +1874,10 @@ __global__ void append_decode_cache_int4_rope_kernel(
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx], &scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx + 8], &scale_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -2468,7 +2034,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2488,8 +2054,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -2538,9 +2103,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -2627,11 +2191,10 @@ __global__ void append_decode_cache_int4_rope_kernel(
&out_scale_vec2);
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx], &scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx + 8], &scale_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -2799,7 +2362,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2815,8 +2378,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -2863,9 +2425,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// dequant + add_bias + rope
@@ -2946,11 +2507,10 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
Load<T, HALF_K_VEC_SIZE>(&qkv_now[right_bias_idx], &right_src_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_now[right_bias_idx + 8], &right_src_vec2);
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx],
&left_scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx + 8],
@@ -3172,7 +2732,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -3192,8 +2752,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -3251,9 +2810,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
&right_out_scale_vec);
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// dequant + add_bias + rope
@@ -3362,11 +2920,10 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
&right_out_scale_vec2);
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx],
&left_scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx + 8],

View File

@@ -15,73 +15,13 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const uint32_t elem_nums =
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -117,6 +57,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -130,8 +71,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
@@ -139,6 +79,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -150,8 +91,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -162,6 +102,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -184,6 +125,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -207,6 +149,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -239,6 +182,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -254,8 +198,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int8_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -264,6 +207,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -277,8 +221,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -289,6 +232,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -304,8 +248,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -314,6 +257,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -327,8 +271,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
}
@@ -339,6 +282,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -373,6 +317,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -390,8 +335,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int4_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -400,6 +344,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -415,8 +360,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -427,6 +371,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -444,8 +389,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -454,6 +398,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -469,8 +414,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
}
@@ -480,6 +424,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -496,10 +441,7 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
typedef cascade_attn_type_traits<T> traits_;
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
typedef typename traits_::type DataType_;
@@ -522,15 +464,79 @@ void DecoderWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_decode_cache_rope_qk_norm(
if (cache_quant_type_str == "none") {
append_decode_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
bool is_scale_channel_wise = false;
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
is_scale_channel_wise = true;
}
if (is_scale_channel_wise) {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -540,6 +546,12 @@ void DecoderWriteCacheWithRoPEKernel(
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
@@ -549,246 +561,84 @@ void DecoderWriteCacheWithRoPEKernel(
bsz,
stream,
use_neox_rotary_style,
rope_3d,
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
rope_3d);
}
} else if (cache_quant_type_str == "cache_fp8") {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
if (cache_quant_type_str == "none") {
append_decode_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
bool is_scale_channel_wise = false;
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
is_scale_channel_wise = true;
}
if (is_scale_channel_wise) {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
}
} else if (cache_quant_type_str == "cache_fp8") {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
nullptr,
nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
"cache_int4_zp]");
}
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
"cache_int4_zp]");
}
}
@@ -800,6 +650,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -816,10 +667,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -829,6 +677,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -845,10 +694,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -857,6 +703,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -873,10 +720,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -885,6 +729,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -901,7 +746,4 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -23,6 +23,7 @@ void DecoderWriteCacheWithRoPEKernel(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -39,6 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight, const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -33,8 +33,7 @@ __global__ void VariableLengthRotaryKernel(
const int64_t elem_cnt,
const int num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<int, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadScaleT = AlignedVector<float, VecSize>;
@@ -65,7 +64,6 @@ __global__ void VariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias;
const int64_t base_idx = token_idx * 3 * hidden_size + bias_idx;
Load<int, VecSize>(&qkv[base_idx], &src_vec);
@@ -74,8 +72,8 @@ __global__ void VariableLengthRotaryKernel(
}
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
if (qkv_id < 2) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -117,8 +115,7 @@ __global__ void VariableLengthRotaryKernel(
const int64_t elem_cnt,
const int num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -145,12 +142,11 @@ __global__ void VariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t base_idx = token_idx * 3 * hidden_size +
qkv_id * hidden_size + hi * last_dim + h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -181,8 +177,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int64_t elem_cnt,
const int num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<int, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadScaleT = AlignedVector<float, VecSize>;
@@ -216,7 +211,6 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
const int bias_idx_left =
qkv_id * full_hidden_size + hi * last_dim + h_bias;
const int bias_idx_right = bias_idx_left + half_lastdim;
@@ -231,8 +225,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
Load<float, VecSize>(&qkv_out_scales[bias_idx_left], &left_out_scale_vec);
Load<float, VecSize>(&qkv_out_scales[bias_idx_right], &right_out_scale_vec);
if (qkv_id < 2) {
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -275,8 +269,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int64_t elem_cnt,
const int num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
@@ -304,7 +297,6 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
const int base_idx_left = token_idx * 3 * full_hidden_size +
qkv_id * full_hidden_size + hi * last_dim +
h_bias;
@@ -312,8 +304,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
const float input_left = static_cast<float>(left_vec[i]);
@@ -366,7 +358,7 @@ __global__ void GQAVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
const int ori_bi = batch_id_per_token[token_idx];;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
@@ -375,7 +367,6 @@ __global__ void GQAVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t bias_idx = hi * last_dim + h_bias;
const int64_t base_idx = token_idx * offset + bias_idx;
Load<int, VecSize>(&qkv[base_idx], &src_vec);
@@ -384,8 +375,8 @@ __global__ void GQAVariableLengthRotaryKernel(
}
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -414,97 +405,6 @@ __global__ void GQAVariableLengthRotaryKernel(
}
}
template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryQKNormKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps
) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
LoadT src_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec, k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
const int half_lastdim = last_dim / 2;
const int offset = (q_num_head + kv_num_head) * last_dim;
const int all_head_num = elem_cnt / last_dim;
for (int global_hi = global_warp_idx; global_hi < all_head_num; global_hi += all_warp_num) {
int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize;
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
}
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / last_dim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < q_num_head) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
for (int i = 0; i < VecSize; i++) {
src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
}
}
template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryKernel(
const T *qkv,
@@ -614,7 +514,6 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t bias_idx = hi * last_dim + h_bias;
const int64_t base_idx = token_idx * offset + bias_idx;
Load<int, VecSize>(&qkv[base_idx], &src_vec);
@@ -622,8 +521,8 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
float input_left = static_cast<float>(src_vec[2 * i]);
@@ -700,15 +599,14 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t bias_idx = hi * last_dim + h_bias;
const int64_t base_idx = token_idx * offset + bias_idx;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
if (qkv_biases) {
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = qkv_biases ? static_cast<float>(src_vec[2 * i]+ bias_vec[2 * i]) : static_cast<float>(src_vec[2 * i]);
@@ -756,8 +654,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<int, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadScaleT = AlignedVector<float, VecSize>;
@@ -787,7 +684,6 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
const int bias_idx_left = hi * last_dim + h_bias;
const int bias_idx_right = bias_idx_left + half_lastdim;
const int base_idx_left =
@@ -802,8 +698,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
Load<float, VecSize>(&qkv_out_scales[bias_idx_left], &left_out_scale_vec);
Load<float, VecSize>(&qkv_out_scales[bias_idx_right], &right_out_scale_vec);
if (hi < (q_num_head + kv_num_head)) {
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -849,8 +745,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
@@ -874,7 +769,6 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
const int base_idx_left =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
@@ -882,8 +776,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
const float input_left = static_cast<float>(left_vec[i]);
@@ -1232,411 +1126,6 @@ __global__ void append_write_cache_kv_c8_qkv(
}
}
template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS,
bool is_need_kv_quant,
bool IsFP8 = true>
__global__ void append_write_cache_kv_c8_qkv_dynamic(
uint8_t *__restrict__ cache_k,
uint8_t *__restrict__ cache_v,
const T *__restrict__ qkv_input,
T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size]
T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size]
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t pad_len = BLOCK_SIZE;
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
const T cache_k_scale = cache_k_scales[kv_head_idx];
const T cache_v_scale = cache_v_scales[kv_head_idx];
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
const uint32_t batch_id = batch_ids[btid];
const uint32_t tile_id = tile_ids[btid];
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
if (seq_len_this_time <= 0) {
return;
}
const int *block_table_now = nullptr;
block_table_now = block_tables + batch_id * max_blocks_per_seq;
const uint32_t num_rows_per_block =
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
const uint32_t start_len = seq_lens_decoder[batch_id];
const uint32_t bf_pad_len = start_len % pad_len;
const uint32_t start_len_pad = start_len - bf_pad_len;
const uint32_t end_len = start_len + seq_len_this_time;
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_scale_smem[BLOCK_SIZE];
if (tile_start >= start_len) {
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
// reset k
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
uint32_t tgt_idx =
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
tid % num_vecs_per_head_k * KV_VEC_SIZE;
for (int block_i = tid / num_vecs_per_head_k;
block_i < BLOCK_SIZE;
block_i += num_token_each_time_k) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&cache_k[tgt_idx + block_i * HEAD_DIM]);
}
// reset v
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
tgt_idx =
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
tid % num_vecs_per_head_v * KV_VEC_SIZE;
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
block_i += num_token_each_time_v) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
}
}
smem_t k_smem(k_smem_ori);
smem_t v_smem(v_smem_ori);
uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp
/*
0 | 1
2 | 3
*/
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
/*
0 | 2
1 | 3
*/
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
tid % 16, wid * num_frags_v * 2 + tid / 16);
// load kv gmem to smem
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
tile_id * num_rows_per_block +
wid * num_frags_z * 16 + tid / 8;
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
(num_heads + kv_head_idx) * kv_h_stride +
tid % 8 * num_elems_per_128b<T>();
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
tid % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y / 4;
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
if (chunk_start >= start_len && chunk_start < end_len) {
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
}
kv_smem_offset_w =
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
k_read_idx += 8 * num_elems_per_128b<T>();
v_read_idx += 8 * num_elems_per_128b<T>();
}
kv_smem_offset_w =
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
2 * num_frags_y;
chunk_start += 4;
k_read_idx +=
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
v_read_idx +=
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
}
}
commit_group();
wait_group<0>();
__syncthreads();
// reduce scale
// 16 rows per warp
uint32_t kv_reduce_frag[4];
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);
T k_local_max_value[num_frags_z * 2];
T v_local_max_value[num_frags_z * 2];
#pragma unroll
for (int i = 0; i < num_frags_z * 2; i++) {
k_local_max_value[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < num_frags_z * 2; i++) {
v_local_max_value[i] = -INFINITY;
}
const int num_kv_heads = gridDim.z;
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
T *cache_k_scale_now = cache_k_scales + scale_offset;
T *cache_v_scale_now = cache_v_scales + scale_offset;
// k scale
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// reduce per thread, 4 threads each row
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
#pragma unroll
for (int i = 0; i < 4; i++) {
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
}
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
// reduce per row
for (int i = 0; i < 2; i++) {
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
// used for quant
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
}
// store
if (tid % 4 == 0) {
const int offset_now = wid * num_frags_z * 16 + tid / 4;
// used for dequant
if (tile_start + offset_now >= start_len) {
if (tile_start + offset_now < end_len) {
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
} else {
cache_k_scale_now[offset_now] = 0;
}
}
if (tile_start + offset_now + 8 >= start_len) {
if (tile_start + offset_now + 8 < end_len) {
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
} else {
cache_k_scale_now[offset_now + 8] = 0;
}
}
}
__syncthreads();
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
}
// v scale
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// reduce per thread, 4 threads each row
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
#pragma unroll
for (int i = 0; i < 4; i++) {
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
}
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
// reduce per row
for (int i = 0; i < 2; i++) {
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
}
// store
if (tid % 4 == 0) {
const int offset_now = wid * num_frags_z * 16 + tid / 4;
// used for dequant
if (tile_start + offset_now >= start_len) {
if (tile_start + offset_now < end_len) {
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
} else {
cache_v_scale_now[offset_now] = 0;
v_scale_smem[offset_now] = 0;
}
}
if (tile_start + offset_now + 8 >= start_len) {
if (tile_start + offset_now + 8 < end_len) {
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
} else {
cache_v_scale_now[offset_now + 8] = 0;
v_scale_smem[offset_now + 8] = 0;
}
}
}
__syncthreads();
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
}
__syncthreads();
// mask, quant, store
using LoadKVT = AlignedVector<uint8_t, 4>;
LoadKVT cache_vec1;
LoadKVT cache_vec2;
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
uint32_t kv_frag[4];
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t write_b_stride = HEAD_DIM;
const uint32_t write_d_stride = BLOCK_SIZE;
uint32_t k_write_idx = block_id * write_n_stride +
kv_head_idx * write_h_stride +
(wid * num_frags_z * 16 + tid / 4) * write_b_stride +
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t k_write_idx_now = k_write_idx_now_z +
fy % 2 * 8 * write_b_stride +
fy / 2 * 32; // + fy % 2 * 16;
// load
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
// quant
T *k_frag_T = reinterpret_cast<T *>(kv_frag);
if (bf_pad_len != 0) {
Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
}
#pragma unroll
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
uint8_t uint_quant_value;
if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
chunk_start_k + (v_id / 4) * 8 < end_len) {
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
} else {
uint_quant_value = 0;
}
if (bf_pad_len != 0) {
if (v_id < 4) {
cache_vec1[v_id] |= uint_quant_value;
} else {
cache_vec2[v_id % 4] |= uint_quant_value;
}
} else {
if (v_id < 4) {
cache_vec1[v_id] = uint_quant_value;
} else {
cache_vec2[v_id - 4] = uint_quant_value;
}
}
}
// store
Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
k_smem_offset_r =
k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
2 * num_frags_y;
chunk_start_k += 16;
}
uint32_t chunk_start_v = tile_start + tid % 4 * 2;
uint32_t v_write_idx = block_id * write_n_stride +
kv_head_idx * write_h_stride +
(wid * num_frags_v * 16 + tid / 4) * write_d_stride +
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
T v_scales[num_frags_z_v * 4];
for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
const int offset = v_i * 16;
const int t_offset = tid % 4 * 2;
v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
}
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
uint32_t v_write_idx_now = v_write_idx_now_v +
fz % 2 * 8 * write_d_stride +
fz / 2 * 32; // + fz % 2 * 16;
// load
v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
// quant
T *v_frag_T = reinterpret_cast<T *>(kv_frag);
if (bf_pad_len != 0) {
Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
}
#pragma unroll
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
uint8_t uint_quant_value;
if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
// store now
} else {
uint_quant_value = 0;
}
if (bf_pad_len != 0) {
if (v_id < 4) {
cache_vec1[v_id] |= uint_quant_value;
} else {
cache_vec2[v_id % 4] |= uint_quant_value;
}
} else {
if (v_id < 4) {
cache_vec1[v_id] = uint_quant_value;
} else {
cache_vec2[v_id % 4] = uint_quant_value;
}
}
}
// store
Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
chunk_start_v += 16;
v_smem_offset_r =
k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
}
v_smem_offset_r = k_smem.advance_offset_by_column<2>(
v_smem_offset_r, wid * num_frags_v + fy) -
16 * num_frags_z_v * num_vecs_per_head;
chunk_start_v -= 16 * num_frags_z_v;
}
}
// Write Cache KV in Append
template <typename T,
uint32_t num_frags_y,
@@ -2023,8 +1512,7 @@ void rotary_qk_variable(
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
dim_head);
} else {
VariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
@@ -2039,8 +1527,7 @@ void rotary_qk_variable(
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
dim_head);
}
} else {
const float *cos_emb = rotary_emb;
@@ -2061,8 +1548,7 @@ void rotary_qk_variable(
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
dim_head);
} else {
NeoxVariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
@@ -2077,72 +1563,11 @@ void rotary_qk_variable(
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
dim_head);
}
}
}
template <typename T, typename QKV_TYPE>
void gqa_rotary_qk_norm_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
const QKV_TYPE *qkv_input, // qkv
const float *qkv_out_scales, // [3, num_head, dim_head]
const T *qkv_bias,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
const int num_heads,
const int kv_num_heads,
const int seq_len,
const int input_output_len,
const int dim_head,
const cudaStream_t &stream,
bool use_neox_style = false,
bool rope_3d = false,
const float *q_norm_weight = nullptr,
const float *k_norm_weight = nullptr,
const float rms_norm_eps = 1e-6) {
int64_t elem_nums =
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
assert(dim_head == 128 && "dim_head must be 128");
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1);
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
GQAVariableLengthRotaryQKNormKernel<T, PackSize>
<<<grid_size, Block_Size, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
void gqa_rotary_qk_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
@@ -2237,8 +1662,7 @@ void gqa_rotary_qk_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
dim_head);
} else {
GQANeoxVariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
@@ -2256,8 +1680,7 @@ void gqa_rotary_qk_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
dim_head);
}
}
}
@@ -2411,11 +1834,10 @@ void CascadeAppendWriteCacheKVC8QKV(
int num_blocks_x_cpu,
int max_seq_len,
bool is_scale_channel_wise,
const std::string& cache_quant_type,
const bool is_fp8,
cudaStream_t &stream,
paddle::Tensor *cache_k_out,
paddle::Tensor *cache_v_out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto num_tokens = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
@@ -2433,77 +1855,49 @@ void CascadeAppendWriteCacheKVC8QKV(
dim3 blocks(32, num_warps);
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
if (cache_quant_type != "block_wise_fp8") {
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, false>;
if (cache_quant_type == "cache_fp8") {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, true>;
}
if (is_scale_channel_wise) {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
false>;
}
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
qkv.data<T>(),
cache_k_scale.data<T>(),
cache_v_scale.data<T>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads);
} else {
auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, true>;
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads);
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, false>;
if (is_fp8) {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, true>;
}
if (is_scale_channel_wise) {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
false>;
}
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
qkv.data<T>(),
cache_k_scale.data<T>(),
cache_v_scale.data<T>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads);
}
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>

View File

@@ -46,10 +46,7 @@ void EncoderWriteCacheWithRopeKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
auto token_num = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
@@ -59,9 +56,28 @@ void EncoderWriteCacheWithRopeKernel(
is_scale_channel_wise = true;
}
if (q_norm_weight && k_norm_weight) {
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
gqa_rotary_qk_norm_variable(
if (num_heads == kv_num_heads) {
rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
} else {
if (!is_scale_channel_wise) {
gqa_rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
@@ -79,80 +95,31 @@ void EncoderWriteCacheWithRopeKernel(
head_dim,
stream,
use_neox_style,
rope_3d,
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
rope_3d);
} else {
PD_THROW(
"gqa_rotary_qk_norm_variable only support gqa mode. channel wise scale and neox style are not supported");
gqa_rotary_qk_quant_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
}
} else {
if (num_heads == kv_num_heads) {
rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
} else {
if (!is_scale_channel_wise) {
gqa_rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
} else {
gqa_rotary_qk_quant_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
}
}
}
const uint32_t block_size = meta_data.block_size;
if (cache_quant_type_str == "none") {
@@ -167,7 +134,7 @@ void EncoderWriteCacheWithRopeKernel(
stream,
key_cache_out,
value_cache_out);
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
@@ -187,7 +154,7 @@ void EncoderWriteCacheWithRopeKernel(
num_blocks,
max_seq_len,
is_scale_channel_wise,
cache_quant_type_str,
cache_quant_type_str == "cache_fp8",
stream,
key_cache_out,
value_cache_out);

View File

@@ -191,36 +191,26 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
}
}
void GetBlockShapeAndSplitKVBlock(
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num)
{
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size,
const int decoder_step_token_num) {
auto stream = seq_lens_encoder.stream();
int bsz = seq_lens_this_time.shape()[0];
paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
auto max_len_tensor =
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor_gpu, bsz);
max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
max_len_tensor, bsz);
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
// max_just_dec_merged_len_this_time, max_system_len,
// max_just_dec_len_without_system
auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
auto max_len_cpu_ptr = max_len_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
int max_enc_len_this_time = max_len_cpu_ptr[1];
int max_dec_len_this_time = max_len_cpu_ptr[2];
@@ -230,7 +220,16 @@ void GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor decoder_batch_ids;
paddle::Tensor decoder_tile_ids_per_batch;
paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -238,14 +237,17 @@ void GetBlockShapeAndSplitKVBlock(
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
@@ -256,12 +258,16 @@ void GetBlockShapeAndSplitKVBlock(
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -269,58 +275,108 @@ void GetBlockShapeAndSplitKVBlock(
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
}
if (max_just_dec_len_this_time > 0) {
// Clear buffer
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
decoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
}
return {encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu /*cpu*/,
max_len_cpu};
}
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const paddle::DataType &seq_lens_encoder_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_this_time_dtype) {
return {
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32};
}
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const std::vector<int64_t> &seq_lens_encoder_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_this_time_shape) {
std::vector<int64_t> dynamic_shape = {-1};
return {dynamic_shape,
dynamic_shape,
{1},
dynamic_shape,
dynamic_shape,
{1},
dynamic_shape,
dynamic_shape,
{1},
{1},
{8}};
}
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_x_cpu",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
})
.Outputs({
})
.Attrs({
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"
})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
.Outputs({paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks"),
paddle::Optional("decoder_batch_ids"),
paddle::Optional("decoder_tile_ids_per_batch"),
paddle::Optional("decoder_num_blocks"),
paddle::Optional("max_len_kv"), "set_max_lengths"})
.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
"group_size: int", "block_size: int",
"decoder_step_token_num: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));

View File

@@ -37,8 +37,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -63,7 +62,6 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
@@ -82,8 +80,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
Load<T, VecSize>(&qkv[base_idx], &src_vec);
// do rope
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -120,7 +118,6 @@ void gqa_rotary_qk_split_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const bool rope_3d,
const cudaStream_t &stream) {
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int PackSize = 16 / sizeof(T);
@@ -149,8 +146,7 @@ void gqa_rotary_qk_split_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
dim_head);
}
template <typename T,
@@ -217,7 +213,7 @@ __global__ void append_cache_kv_c16(
// load k_smem 64 rows 128 cols
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -235,7 +231,7 @@ __global__ void append_cache_kv_c16(
// deal k_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
uint32_t col_idx = fy * 16 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
// layout
@@ -278,7 +274,7 @@ __global__ void append_cache_kv_c16(
// load v_smem 64 rows 128 cols
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -296,7 +292,7 @@ __global__ void append_cache_kv_c16(
// deal v_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
uint32_t col_idx = fy * 16 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
// layout
@@ -400,7 +396,7 @@ __global__ void append_cache_kv_c8(
// load v_smem 64 rows, 128 cols
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -418,7 +414,7 @@ __global__ void append_cache_kv_c8(
// deal k_smem 64 rows, 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter
uint32_t col_idx = fy * 32 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
// layout
@@ -466,7 +462,7 @@ __global__ void append_cache_kv_c8(
tid % 4 * num_elems_per_128b<CacheT>();
// load v_smem 128 rows 64 cols
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 noce, need 1 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -485,7 +481,7 @@ __global__ void append_cache_kv_c8(
// deal v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 noce, need 2 iter
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
// layout
@@ -590,9 +586,9 @@ __global__ void append_cache_kv_c4(
#pragma unroll
for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) {
cache_k_scale_smem[i] = cache_k_scale_now[i];
cache_k_zero_point_smem[i] = cache_k_zp_now[i] + static_cast<T>(136.f);
cache_k_zero_point_smem[i] = cache_k_zp_now[i] - static_cast<T>(136.f);
cache_v_scale_smem[i] = cache_v_scale_now[i];
cache_v_zero_point_smem[i] = cache_v_zp_now[i] + static_cast<T>(136.f);
cache_v_zero_point_smem[i] = cache_v_zp_now[i] - static_cast<T>(136.f);
}
smem_t k_smem(smem);
@@ -614,7 +610,7 @@ __global__ void append_cache_kv_c4(
// load k_smem 64 rows 128 cols
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 noce, need 1 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -632,7 +628,7 @@ __global__ void append_cache_kv_c4(
// deal k_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 noce, need 2 iter
uint32_t col_idx = fy * 64 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
@@ -644,25 +640,25 @@ __global__ void append_cache_kv_c4(
convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]);
if (row_idx < end_idx) {
k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
k_tile_ptr0[1] = (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
k_tile_ptr0[8] = (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
k_tile_ptr0[9] = (frag_dq_T[3] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
k_tile_ptr0[16] = (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
k_tile_ptr0[17] = (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
k_tile_ptr0[24] = (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
k_tile_ptr0[25] = (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
k_tile_ptr0[16] = frag_dq_T[8] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
k_tile_ptr0[17] = frag_dq_T[9] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
k_tile_ptr0[24] = frag_dq_T[10] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
k_tile_ptr0[25] = frag_dq_T[11] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
}
if (row_idx + 8 < end_idx) {
k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
k_tile_ptr1[1] = (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
k_tile_ptr1[8] = (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
k_tile_ptr1[9] = (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
k_tile_ptr1[16] = (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
k_tile_ptr1[17] = (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
k_tile_ptr1[24] = (frag_dq_T[14] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
k_tile_ptr1[25] = (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
k_tile_ptr1[16] = frag_dq_T[12] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
k_tile_ptr1[17] = frag_dq_T[13] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
k_tile_ptr1[24] = frag_dq_T[14] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
k_tile_ptr1[25] = frag_dq_T[15] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
}
col_idx += 32;
}
@@ -685,7 +681,7 @@ __global__ void append_cache_kv_c4(
tid % 2 * num_elems_per_128b<CacheT>();
// load v_smem 128 rows 64 rows
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -704,7 +700,7 @@ __global__ void append_cache_kv_c4(
// deal v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
// layout
@@ -715,36 +711,36 @@ __global__ void append_cache_kv_c4(
convert_int4(frag_dq_T, v_frag[2 * i]);
convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]);
if (kv_idx < end_idx) {
v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[0] = (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 1 < end_idx) {
v_tile_ptr0[kv_t_stride] = (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[kv_t_stride] = (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 8 < end_idx) {
v_tile_ptr0[8 * kv_t_stride] = (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[8 * kv_t_stride] = (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 9 < end_idx) {
v_tile_ptr0[9 * kv_t_stride] = (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[9 * kv_t_stride] = (frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 16 < end_idx) {
v_tile_ptr0[16 * kv_t_stride] = (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[16 * kv_t_stride] = (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[16 * kv_t_stride] = frag_dq_T[8] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[16 * kv_t_stride] = frag_dq_T[12] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 17 < end_idx) {
v_tile_ptr0[17 * kv_t_stride] = (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[17 * kv_t_stride] = (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[17 * kv_t_stride] = frag_dq_T[9] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[17 * kv_t_stride] = frag_dq_T[13] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 24 < end_idx) {
v_tile_ptr0[24 * kv_t_stride] = (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[24 * kv_t_stride] = (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[24 * kv_t_stride] = frag_dq_T[10] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[24 * kv_t_stride] = frag_dq_T[14] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 25 < end_idx) {
v_tile_ptr0[25 * kv_t_stride] = (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[25 * kv_t_stride] = (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[25 * kv_t_stride] = frag_dq_T[11] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[25 * kv_t_stride] = frag_dq_T[15] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
kv_idx += 32;
}
@@ -894,8 +890,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int kv_token_num,
const int max_seq_len,
const std::string& cache_quant_type,
const bool rope_3d) {
const std::string& cache_quant_type) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -958,34 +953,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
rotary_embs.dims()[2],
head_dim,
rope_3d,
stream);
if (token_num < kv_token_num) {
AppendCacheKV<data_t, 128, 64>(
key_cache,
value_cache,
cache_k_dequant_scales.get(),
cache_v_dequant_scales.get(),
cache_k_zp.get(),
cache_v_zp.get(),
seq_lens_this_time,
seq_lens_decoder,
cu_seqlens_k,
block_tables,
cache_batch_ids,
cache_tile_ids,
cache_num_blocks,
max_blocks_per_seq,
kv_num_heads,
cache_quant_type,
&k,
&v,
stream
);
}
// write cache
if (cache_quant_type == "none") {
CascadeAppendWriteCacheKVQKV<data_t>(
@@ -1000,7 +970,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
meta_data,
*const_cast<paddle::Tensor*>(&key_cache),
@@ -1018,7 +988,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
kv_num_blocks_data,
max_seq_len,
false, // is_scale_channel_wise
cache_quant_type,
cache_quant_type == "cache_fp8", // is_fp8
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
@@ -1068,6 +1038,30 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
}
}
}
if (token_num < kv_token_num) {
AppendCacheKV<data_t, 128, 64>(
key_cache,
value_cache,
cache_k_dequant_scales.get(),
cache_v_dequant_scales.get(),
cache_k_zp.get(),
cache_v_zp.get(),
seq_lens_this_time,
seq_lens_decoder,
cu_seqlens_k,
block_tables,
cache_batch_ids,
cache_tile_ids,
cache_num_blocks,
max_blocks_per_seq,
kv_num_heads,
cache_quant_type,
&k,
&v,
stream
);
}
return {q, k, v, qkv_out};
}

View File

@@ -18,166 +18,6 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1, typename InT = T>
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const float*
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int output_inner_dim,
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadInT src_vec;
LoadFloat scale_vec;
LoadT bias_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec;
LoadFloat k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
const int half_head_size = head_size / 2;
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
const int token_id = linear_index / hidden_size;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
const int write_q_idx =
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
if (qkv_biases) {
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
}
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (qkv_out_scales) {
input_left *= scale_vec[2 * i];
input_right *= scale_vec[2 * i + 1];
}
if (qkv_biases) {
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
}
if (hi < num_heads + gqa_group_size) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
}
}
if (hi < (num_heads + gqa_group_size)) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
} else {
// write k/v
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
kv_head_idx * block_size * head_size +
block_offset * head_size + h_bias);
// write
if (hi < num_heads + gqa_group_size) {
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
}
}
}
}
template <int VecSize = 4, int HeadDim = 128>
__global__ void append_clear_cache_int8_block(
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
@@ -353,8 +193,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -414,9 +253,8 @@ __global__ void append_speculate_cache_rope_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -488,8 +326,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -553,9 +390,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * head_size + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -640,8 +476,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -687,9 +522,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
}
@@ -749,11 +583,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
T scale;
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
} else {
scale = __ldg(&cache_v_scales[kv_head_idx]);
@@ -875,8 +708,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -925,9 +757,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
&left_out_scale_vec);
@@ -1022,11 +853,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
T scale;
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -1258,8 +1088,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1316,9 +1145,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -1407,11 +1235,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// &out_scale_vec2);
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -1604,8 +1431,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1755,11 +1581,10 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
&right_out_scale_vec2);
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
&left_scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],

View File

@@ -15,77 +15,6 @@
#include "speculate_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
if (use_neox_style) {
PD_THROW(
"append_speculate_cache_rope_qk_norm not support neox rope yet");
} else {
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(qkv,
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
output_inner_dim,
dim_head,
block_size,
elem_nums,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
}
// rope + write
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope(const QKV_TYPE* qkv,
@@ -110,8 +39,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
@@ -145,8 +73,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_rope_kernel<T, PackSize>
<<<grid_size, threads_per_block, 0, stream>>>(
@@ -169,8 +96,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
@@ -199,8 +125,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -242,8 +167,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -267,8 +191,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
@@ -299,8 +222,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -344,8 +266,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -371,8 +292,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
template <typename T, typename QKV_TYPE>
@@ -393,15 +313,11 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
typedef cascade_attn_type_traits<T> traits_;
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
typedef typename traits_::type DataType_;
@@ -426,184 +342,142 @@ void SpeculateWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope_qk_norm(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
}
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
}
@@ -626,15 +500,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -656,15 +526,11 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -685,15 +551,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
@@ -716,12 +578,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -35,12 +35,8 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -56,7 +56,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -104,6 +103,5 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -99,6 +98,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -43,7 +43,4 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -27,7 +27,6 @@ struct AppendAttnMetaData {
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
const int *mask_offset = nullptr;
};
__forceinline__ __host__ __device__ int div_up(int a, int b) {
@@ -431,9 +430,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (group_size == 12) { \
constexpr size_t GROUP_SIZE = 12; \
__VA_ARGS__ \
} else if (group_size == 14) { \
constexpr size_t GROUP_SIZE = 14; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
@@ -441,15 +437,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
if (is_dynamic_cfp8) { \
constexpr bool IsDynamicC8 = true; \
__VA_ARGS__ \
} else { \
constexpr bool IsDynamicC8 = false; \
__VA_ARGS__ \
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
@@ -487,9 +474,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
if (causal) { \
constexpr bool CAUSAL = true; \
__VA_ARGS__ \
} else { \
constexpr bool CAUSAL = false; \
__VA_ARGS__ \
}
#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \
@@ -575,37 +559,3 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
convert_int8(result, source);
}
}
constexpr int kWarpSize = 32;
template<typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
*m2 += b_m2;
}
template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
*m2 = thread_m2;
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
WelfordCombine1(b_m2, m2);
}
}
template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}

View File

@@ -77,54 +77,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &out_linear_shifts,
const paddle::optional<paddle::Tensor> &out_linear_smooths,
const paddle::optional<paddle::Tensor> &mask_offset,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
const int max_input_length, const float quant_max_bound,
const float quant_min_bound, const float out_linear_in_scale,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int max_partition_size, const int encoder_max_partition_size,
const int speculate_max_draft_token_num, const bool causal,
const bool speculate_decoder);
void AppendAttentionWithOutput(
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
const paddle::Tensor &encoder_tile_ids_per_batch,
const paddle::Tensor &encoder_num_blocks,
const paddle::Tensor &kv_batch_ids,
const paddle::Tensor &kv_tile_ids_per_batch,
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
paddle::Tensor &fmha_out,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::optional<paddle::Tensor> &qkv_bias,
const paddle::optional<paddle::Tensor> &qkv_out_scales,
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_k_zp,
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &out_linear_shifts,
const paddle::optional<paddle::Tensor> &out_linear_smooths,
const paddle::optional<paddle::Tensor> &mask_offset,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
const int max_input_length, const float quant_max_bound,
@@ -154,8 +107,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const int kv_token_num, const int max_seq_len,
const std::string &cache_quant_type,
const bool rope_3d);
const std::string &cache_quant_type);
std::vector<paddle::Tensor>
PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
@@ -172,29 +124,11 @@ paddle::Tensor FusedExpertMoeFunc(
const std::string &quant_method, const int moe_topk,
const bool norm_topk_prob, const bool group_moe);
std::vector<paddle::Tensor> MacheteMMKernel(
paddle::Tensor const& A, paddle::Tensor const& B,
paddle::optional<paddle::Tensor> const& maybe_group_scales,
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
paddle::optional<paddle::Tensor> const& maybe_token_scales,
std::string const& b_type_str,
std::string const& maybe_out_type_str,
int64_t const& maybe_group_size,
std::string const& maybe_schedule);
std::vector<paddle::Tensor> MachetePrepackBKernel(
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
std::string const& maybe_group_scales_type_str);
std::vector<std::string> MacheteSupportedSchedules(
std::string const& a_type_str, std::string const& b_type_str);
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode);
const bool group_moe, const bool topk_only_mode);
std::vector<paddle::Tensor>
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
@@ -254,9 +188,7 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size);
const std::string& quant_method, const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
@@ -299,25 +231,12 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
const int layer_id);
void GetBlockShapeAndSplitKVBlock(
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size,
const int decoder_step_token_num);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
@@ -347,12 +266,13 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &seq_lens,
const paddle::Tensor &end_ids,
const paddle::Tensor &next_tokens,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const bool beam_search);
void GetStopFlagsMultiSeqs(
const paddle::Tensor &topk_ids, const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len, const paddle::Tensor &end_ids);
void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop, // only on cpu
@@ -386,11 +306,9 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const paddle::optional<paddle::Tensor> &draft_tokens,
const paddle::optional<paddle::Tensor> &step_draft_tokens,
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
const int block_size,
const int max_draft_tokens);
const int block_size);
paddle::Tensor
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
@@ -400,7 +318,7 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index,
const paddle::Tensor &mm_token_num_len,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &hidden_states);
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text);
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
const paddle::Tensor &topk_idx,
@@ -603,7 +521,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(int64_t _fa);
@@ -686,7 +604,7 @@ void SpeculateVerify(
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
@@ -717,22 +635,6 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
const paddle::Tensor &step_seq_lens_this_time,
const paddle::Tensor &accept_num,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens);
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
@@ -747,20 +649,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
const int max_draft_tokens);
void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens);
// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -776,12 +664,9 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
@@ -790,8 +675,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
const bool splitwise_prefill);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
@@ -872,33 +756,6 @@ void SpeculateStepPaddle(
const int encoder_decoder_block_num,
const int max_draft_tokens);
void MergePrefillDecodeOutput(
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token);
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
const paddle::Tensor &top_p,
const paddle::optional<paddle::Tensor> &top_k,
int64_t seed);
std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
const paddle::Tensor &top_k);
std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
const paddle::Tensor &min_p);
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -952,7 +809,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* append_attention
*/
m.def("append_attention", &AppendAttention, "append attention function");
m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
/**
* gqa_rope_write_cache.cu
* gqa_rope_write_cache
@@ -984,7 +840,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
py::arg("gating_output"), py::arg("gating_correction_bias"),
py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function");
py::arg("topk_only_mode"), "moe export dispatch function");
/**
* moe/fused_moe/ep_moe_prefill_func.cu
@@ -1008,33 +864,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"),
py::arg("block_size"),
"per token per block quant and padding transpose scale");
"per token per block quant and padding tranpose scale");
m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"),
py::arg("recv_expert_count"), py::arg("block_size"),
"per token per block quant");
#ifdef ENABLE_MACHETE
/*machete/machete_mm.cu
* machete_mm
*/
m.def("machete_mm", &MacheteMMKernel, py::arg("A"), py::arg("B"), py::arg("maybe_group_scale"),
py::arg("maybe_group_zeros"), py::arg("maybe_channel_scales"), py::arg("maybe_token_scales"),
py::arg("b_type_str"), py::arg("maybe_out_type_str"), py::arg("maybe_group_size"),
py::arg("maybe_schedule"),
"machete mm function");
/*machete/machete_prepack_B.cu
* machete_prepack_B
*/
m.def("machete_prepack_B", &MachetePrepackBKernel, "machete prepacked B function");
/*machete/machete_supported_schedules.cu
* machete_supported_schedules
*/
m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function");
#endif
/**
* moe/fused_moe/moe_topk_select.cu
* moe_topk_select
@@ -1051,7 +886,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
/**
* moe/fused_moe/moe_expert_ffn_wint2.cu
* moe/fused_moe/moe_ffn_wint2.cu
* moe_expert_ffn_wint2
*/
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
@@ -1119,6 +954,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("set_stop_value_multi_ends", &GetStopFlagsMulti,
"update_inputs function");
/**
* stop_generation_multi_stop_seqs.cu
* set_stop_value_multi_seqs
*/
m.def("set_stop_value_multi_seqs", &GetStopFlagsMultiSeqs,
"update_inputs function");
/**
* update_inputs.cu
@@ -1248,7 +1089,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
@@ -1256,12 +1097,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
@@ -1275,14 +1112,4 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function");
m.def("rejection_top_p_sampling", &TopPSamplingReject, "rejection_top_p_sampling function");
m.def("top_k_renorm_probs", &TopKRenorm, "top_k_renorm_probs function");
m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function");
m.def("save_output", &SaveOutMmsgStatic, "save_output function");
}

View File

@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto stream = inp.stream();
@@ -163,12 +163,3 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void free_shared_buffer(fptr_t buffer) {
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
PD_BUILD_STATIC_OP(all_reduce)
.Inputs({"inp",
"out"})
.Outputs({"new_out"})
.Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
.SetInplaceMap({{"out", "new_out"}})
.SetKernelFn(PD_KERNEL(all_reduce));

View File

@@ -6,8 +6,6 @@
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
#include "helper.h"
// clang-format on
/*

View File

@@ -133,18 +133,10 @@ public:
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value; // 64
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint2b_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8
public:
// using Layout = layout::ColumnMajor;
// static constexpr int ElementsPerAccess = 16; // at least 4-bytes
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint2b_t>::value; // 64
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>

View File

@@ -18,12 +18,14 @@
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
@@ -376,23 +378,38 @@ template <
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
template <
@@ -424,23 +441,38 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};
} // namespace threadblock

View File

@@ -19,7 +19,7 @@
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
namespace cutlass {
namespace gemm {
@@ -379,23 +379,38 @@ template <
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
template <
@@ -427,23 +442,38 @@ struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmen
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};
} // namespace threadblock

View File

@@ -1,182 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
/// Partial specialization:
///
/// A: row-major
/// B: uint2b_t, column-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix production operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator_,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::RowMajor, uint2b_t, layout::ColumnMajor,
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
Operator_, false, CacheOpA, CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using ElementB = uint2b_t;
using LayoutB = layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
static int const kStages = Stages;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK>;
// Divisility requirements
static_assert(
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
/// Number of threads per warp
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access of B
static constexpr int kMaxThreadsForB =
(Shape::kK * Shape::kN * sizeof_bits<ElementB>::value) / kAccessSizeInBits;
static constexpr int kThreadsForB =
kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB;
/// Default Operator
using Operator = Operator_;
// Warp thread arrangement
static int const kWarpThreadArrangementContiguousA =
Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
static int const kWarpThreadArrangementStridedA =
kWarpSize / kWarpThreadArrangementContiguousA;
static int const kWarpThreadArrangementContiguousB =
Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
static int const kWarpThreadArrangementStridedB =
kWarpSize / kWarpThreadArrangementContiguousB;
//
// Shared memory layouts
//
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementA>::value, Shape::kK>;
// Shared memory layout
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementB>::value, Shape::kK>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
IteratorThreadMapA>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreadsForB,
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
kWarpThreadArrangementStridedB>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
IteratorThreadMapB>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Policy used to define MmaPipelined
using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
MatrixShape<0, 0>, WarpCount::kK>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -1,246 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_core.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <typename ThreadblockShape, typename ElementT, int GroupSize>
struct DefaultQuantParamsIterators {
private:
static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
static constexpr int kRows =
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
static constexpr int kColumns = ThreadblockShape::kN;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment, kAlignment>;
public:
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor, 0,
IteratorThreadMap, kAlignment>;
using SmemIterator = Iterator;
};
template <typename ThreadblockShape, int GroupSize>
struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
private:
static constexpr int kAlignment = 32 / sizeof_bits<uint4b_t>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
static constexpr int kRows =
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
static constexpr int kColumns =
(GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment, kAlignment>;
public:
using AccessType = cutlass::Array<uint4b_t, kAlignment>;
using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
MatrixShape<kRows, kColumns>, uint4b_t, layout::RowMajor,
0, IteratorThreadMap, AccessType>;
using SmemIterator = Iterator;
};
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
struct DefaultWint2xMma;
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
/// Operator performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator, SharedMemoryClear>
{
public:
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint2b_t>::value,
"Element B must be uint2b_t");
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
using ElementSuperScale = ElementA;
using ElementLocalScale = uint4b_t;
using ElementCodeScaleZp = float;
static constexpr int kGroupSize = 64;
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
private:
static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int kRowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be disivle by kColumnsInterleaved");
static_assert(kRowsPerTile == MmaCore::Shape::kK, "");
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement;
static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), "");
using IteratorShapeB = MatrixShape<
MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>;
using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<IteratorShapeB::kRow, IteratorShapeB::kColumn>,
ThreadMapB::kThreads,
layout::PitchLinearShape<WarpArrangement::kContiguous * kColumnsInterleaved,
WarpArrangement::kStrided / kColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
AccessTypeB>;
private:
// Define iterators over tiles from extra quant params for B operand
using IteratorSuperScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementSuperScale, -1>::Iterator;
using SmemIteratorSuperScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementSuperScale, -1>::SmemIterator;
using IteratorLocalScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator;
using SmemIteratorLocalScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator;
using IteratorCodeScaleZp = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
public:
using QuantParamsAccessor = Wint2ParamsAccessor<
ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale,
IteratorLocalScale, SmemIteratorLocalScale,
IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
typename MmaCore::Shape,
IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy,
kStages, QuantParamsAccessor, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -63,8 +63,8 @@ template <
typename Policy_,
/// Number of stages,
int Stages,
/// Size of extra quantized params
typename QuantParamsShape>
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
@@ -89,18 +89,10 @@ public:
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of warp-level GEMM operations per load for B
static constexpr int kWarpGemmIterationsPerLoadForB =
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
static constexpr int kWarpLoadIterationsForB =
kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;
/// Number of stages
static int const kStages = Stages;
@@ -112,6 +104,8 @@ public:
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
// using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
@@ -136,11 +130,20 @@ public:
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
/// Shape of all quant params in shared memory
using QuantParamsShapeB = QuantParamsShape;
// w uint8; local_scale uint8;
constexpr static int kZippedRowsPerStages =
Shape::kK / 4 + (Shape::kK + 127) / 128;
// code_scale float; code_zp float; super_scale ElementB
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
sizeof_bits<typename Operator::ElementB>::value / 8;
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
public:
//
@@ -153,8 +156,12 @@ public:
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer for extra quant params of B operand
AlignedBuffer<uint8_t, QuantParamsShapeB::kCount> operand_quant_params_B;
/// Buffer for quanted B operand
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
/// Buffer for unzip B operand
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
operand_unzip_B;
public:
//
@@ -184,6 +191,14 @@ public:
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
CUTLASS_HOST_DEVICE
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
CUTLASS_HOST_DEVICE
typename Operator::ElementB *operand_unzip_B_ptr() {
return operand_unzip_B.data();
}
};
protected:

View File

@@ -45,8 +45,7 @@
#include "cutlass_extensions/arch/memory_copy_sm80.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -87,15 +86,15 @@ template <
typename Policy_,
/// Number of stages,
int Stages,
/// Accessor for extra quantized params
typename QuantParamsAccessor_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaMultistage :
public Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape> {
public Wint2xMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape>;
using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
@@ -108,11 +107,8 @@ public:
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
/// Accessor for extra quantized params
using QuantParamsAccessor = QuantParamsAccessor_;
using QuantArguments = typename QuantParamsAccessor::Arguments;
static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK;
using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
@@ -133,18 +129,6 @@ public:
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
//using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout;
using LayoutScale = layout::RowMajor;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
using WarpDequantizer =
warp::MmaTensorOpWin2xDequantizer<Operator,
typename Base::WarpGemm,
Operand::kB,
typename WarpTransformedFragmentB::Element,
LayoutScale,
QuantParamsAccessor::kGroupSize>;
static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed");
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
@@ -190,37 +174,18 @@ public:
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale;
using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp;
using FragmentLocalScale = typename WarpDequantizer::FragmentLocalScale;
/// Temporary accumulator to facilitate staged-accumulation
FragmentC tmp_accum_;
/// Pair of A fragments used to overlap shared memory loads and math instructions
WarpTransformedFragmentA warp_frag_A_[2];
WarpLoadedFragmentA warp_loaded_frag_A_[2];
WarpTransformedFragmentA warp_transformed_frag_A_[2];
/// Pair of B fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentB warp_loaded_frag_B_;
WarpTransformedFragmentB warp_frag_B_[2];
/// channel-wise quant params
FragmentCodeScaleZp warp_frag_code_scale_;
FragmentCodeScaleZp warp_frag_code_zp_;
FragmentSuperScale warp_frag_super_scale_;
/// group-wise quant params
FragmentLocalScale warp_frag_local_scale_;
WarpLoadedFragmentB warp_loaded_frag_B_[2];
WarpTransformedFragmentB warp_transformed_frag_B_[2];
};
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool IsTileInterleaveLayout =
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
@@ -237,18 +202,17 @@ public:
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Accessor for extra quant params for B
QuantParamsAccessor quant_params_accessor_B_;
// Wint2 unzip operator
WarpDequantizer warp_dequantizer_;
/// Shared memory write stage index
int smem_write_stage_idx_;
/// Shared memory read stage index
int smem_read_stage_idx_;
uint8_t* column_wise_smem_ptr_B_;
uint8_t* smem_zipped_ptr_B_;
int smem_zipped_bytes_per_stage_B_;
public:
/// Construct from tensor references
@@ -262,15 +226,10 @@ public:
int warp_idx,
///< ID of each thread within a warp
int lane_idx
) : Base(shared_storage, thread_idx, warp_idx, lane_idx),
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx),
warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(),
quant_params_accessor_B_.local_scale_ref(),
quant_params_accessor_B_.code_scale_ref(),
quant_params_accessor_B_.code_zp_ref(),
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0)
{
@@ -291,6 +250,11 @@ public:
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
}
/// Advance shared memory read-iterators to the next stage
@@ -302,22 +266,28 @@ public:
if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0});
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
smem_read_stage_idx_ = 0;
}
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
}
/// Advance global memory read-iterators and shared memory write-iterators to the stage
template <typename TileDequanterB>
CUTLASS_DEVICE
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B)
void advance_smem_write_stage(
IteratorA &iterator_A,
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B)
{
// Advance global iterators
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
//iterator_B.add_tile_offset({1, 0});
tile_dequanter_B.AddTileOffset({1, 0});
// Advance shared iterators
smem_iterator_A_.add_tile_offset({0, 1});
smem_iterator_B_.add_tile_offset({1, 0});
//smem_iterator_B_.add_tile_offset({1, 0});
// Increment shared memory write stage index
++smem_write_stage_idx_;
@@ -325,7 +295,7 @@ public:
if (smem_write_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
//smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx_ = 0;
}
}
@@ -368,14 +338,9 @@ public:
}
}
template <bool GlobalToSharedB>
CUTLASS_DEVICE
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
if constexpr (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
}
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
@@ -395,14 +360,13 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false;
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, is_valid);
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, is_valid);
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
@@ -411,6 +375,7 @@ public:
++this->smem_iterator_B_;
}
}
__syncthreads();
}
CUTLASS_DEVICE
@@ -434,6 +399,8 @@ public:
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
@@ -444,12 +411,9 @@ public:
}
}
template <bool GlobalToSharedB, bool InitStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
@@ -469,23 +433,35 @@ public:
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
if (InitStage) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
} else {
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
}
++iterator_B;
}
++this->smem_iterator_B_;
}
__syncthreads();
}
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
template <typename TileDequanterB>
CUTLASS_DEVICE
void prologue(
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
TileDequanterB &tile_dequanter_B,
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
{
// Issue several complete stages
@@ -500,18 +476,11 @@ public:
copy_tiles_and_advance_per_stage_A(iterator_A);
// Async copy zipped B to shared memory.
copy_tiles_and_advance_per_stage_B(iterator_B);
// Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
if (stage == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(mma_quant_args, stage);
} else {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
}
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
// Move to the next write stage
advance_smem_write_stage(iterator_A, iterator_B);
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
@@ -541,10 +510,6 @@ public:
++last_smem_iterator_A;
}
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
@@ -577,57 +542,57 @@ public:
}
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
template <typename TileDequanterB>
CUTLASS_DEVICE
void mac_loop_iter(
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
int stage)
{
const int mma_stage = stage - Base::kStages + 1;
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
++this->warp_tile_iterator_B_;
}
// load next-tile of group-wise local_scale from shared memory
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
}
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
// Load the next warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
// dequantizes next warp-tile
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK,
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
// Unpack and dequant the first stage of B.
int unpack_stage = stage - Base::kStages + 2;
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, unpack_stage);
// Copy dequatized data to shared memory used by mma core.
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
}
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
if (warp_mma_k > 0) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
}
// Execute the current warp-tile of MMA operations
if constexpr (Detail::kStagedAccumulation) {
if (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_
);
@@ -639,22 +604,22 @@ public:
} else {
warp_mma_(
accum,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
accum);
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
accum
);
}
// Except for the last warp-tile, all warp-tiles issue their share of
// global->shared fragment copies
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
if (warp_mma_k == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
}
}
@@ -663,15 +628,9 @@ public:
// - moves to the next global fetch stage
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
// Performs the last warp-tile's share of global->shared fragment copies
if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) {
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
}
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) {
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
}
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
@@ -680,66 +639,69 @@ public:
gmem_wait();
// Move to the next global fetch stage
advance_smem_write_stage(iterator_A, iterator_B);
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
advance_smem_read_stage();
int byte_offset = quant_params_accessor_B_.advance_smem_read_stage();
warp_dequantizer_.add_pointer_offset(byte_offset);
// Disable global fetching when done with global fetch iterations
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
}
// The last warp-tile also converts the shared memory fragments used by
// the first warp-tile of the next iteration, if necessary (so we can
// immediately start issuing MMA instructions at the top of the loop )
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
}
}
}
/// Perform the specified number of threadblock mainloop iterations of matrix
/// multiply-accumulate. Assumes prologue has been initiated.
template <typename TileDequanterB>
CUTLASS_DEVICE
void gemm_iters(
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args)
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
{
PipeState pipe_state;
// Unpack and dequant the first stage of B.
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
// Load first warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
++this->warp_tile_iterator_B_;
warp_dequantizer_.load(pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_);
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
// Load first warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
++this->warp_tile_iterator_A_;
// Dequantize B to in register
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[0],
0,
0);
// Copy dequatized data to shared memory used by mma core.
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
if constexpr (Detail::kStagedAccumulation) {
// Load first warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
++this->warp_tile_iterator_B_;
// Transform, if necessary, the first warp-tile's shared memory fragments
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[0],
pipe_state.warp_transformed_frag_B_[0],
pipe_state.warp_loaded_frag_A_[0],
pipe_state.warp_loaded_frag_B_[0]);
if (Detail::kStagedAccumulation) {
pipe_state.tmp_accum_.clear();
}
@@ -753,13 +715,13 @@ public:
accum,
iterator_A,
iterator_B,
mma_quant_args,
tile_dequanter_B,
gemm_k_iterations,
stage);
stage += 1;
}
if constexpr (Detail::kStagedAccumulation) {
if (Detail::kStagedAccumulation) {
plus<FragmentC> plus_accum;
accum = plus_accum(accum, pipe_state.tmp_accum_);
}
@@ -799,12 +761,14 @@ public:
else
{
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
}
smem_read_stage_idx_ = smem_write_stage_idx_;
}
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
template <typename TileDequanterB>
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
@@ -815,13 +779,13 @@ public:
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterators for extra quant params for B
QuantArguments mma_quant_args,
///< pre-load and dequantize B to shared memory
TileDequanterB tile_dequanter_B,
///< initial value of accumulator
FragmentC const &src_accum) {
// Prologue (start fetching iterations of global fragments into shared memory)
prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations);
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
// Wait until we have at least one completed global fetch stage
gmem_wait();
@@ -830,7 +794,7 @@ public:
accum = src_accum;
// Perform the MAC-iterations
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
}
};

View File

@@ -1,315 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/trace.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
/// Original data type
typename T,
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterators over super scales in global memory
typename IteratorSuperScale_,
/// Iterators over super scales in shared memory
typename SmemIteratorSuperScale_,
/// Iterators over local scales in global memory
typename IteratorLocalScale_,
/// Iterators over local scales in shared memory
typename SmemIteratorLocalScale_,
/// Iterators over code scales and zps in global memory
typename IteratorCodeScaleZp_,
/// Iterators over code scales and zps in shared memory
typename SmemIteratorCodeScaleZp_,
/// Number of stages,
int Stages_,
/// Group size for quantization
int GroupSize_>
class Wint2ParamsAccessor {
public:
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
using ElementType = T;
using Shape = Shape_;
using IteratorSuperScale = IteratorSuperScale_;
using SmemIteratorSuperScale = SmemIteratorSuperScale_;
using IteratorLocalScale = IteratorLocalScale_;
using SmemIteratorLocalScale = SmemIteratorLocalScale_;
using IteratorCodeScaleZp = IteratorCodeScaleZp_;
using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_;
constexpr static int kStages = Stages_;
constexpr static int kGroupSize = GroupSize_;
using ElementSuperScale = typename IteratorSuperScale::Element;
using LayoutSuperScale = typename IteratorSuperScale::Layout;
/// local_scale uint4 and group-wise
using ElementLocalScale = typename IteratorLocalScale::Element;
using LayoutLocalScale = typename IteratorLocalScale::Layout;
static_assert(platform::is_same<ElementLocalScale, uint4b_t>::value,
"local_scale's type must be uint4b_t.");
using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
/// 2 uint4b_t values are stored in a single uint8_t
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;
constexpr static int kLocalScaleRows =
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
using SmemElement = uint8_t;
constexpr static int kSmemRows =
kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2;
constexpr static int kSmemColumns = Shape::kN;
using QuantParamsShape = MatrixShape<kSmemRows, kSmemColumns>;
constexpr static int kSuperScaleSmemOffset = 0;
constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale);
constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
/// TensorRef type for loading element from a tensor
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;
struct Arguments {
IteratorSuperScale iterator_super_scale;
IteratorLocalScale iterator_local_scale;
IteratorCodeScaleZp iterator_code_scale;
IteratorCodeScaleZp iterator_code_zp;
int local_scale_pointer_offset;
CUTLASS_DEVICE
Arguments(IteratorSuperScale iterator_super_scale,
IteratorLocalScale iterator_local_scale,
IteratorCodeScaleZp iterator_code_scale,
IteratorCodeScaleZp iterator_code_zp,
int local_scale_pointer_offset)
: iterator_super_scale(iterator_super_scale),
iterator_local_scale(iterator_local_scale),
iterator_code_scale(iterator_code_scale),
iterator_code_zp(iterator_code_zp),
local_scale_pointer_offset(local_scale_pointer_offset) {}
};
private:
//
// Data members
//
/// Begin address of shared memory
uint8_t* smem_pointer_;
/// Iterator to write threadblock-scoped tile of super scale operand to shared memory
SmemIteratorSuperScale smem_iterator_super_scale_;
/// Iterator to write threadblock-scoped tile of local scale operand to shared memory
SmemIteratorLocalScale smem_iterator_local_scale_;
/// Iterator to write threadblock-scoped tile of code scale operand to shared memory
SmemIteratorCodeScaleZp smem_iterator_code_scale_;
/// Iterator to write threadblock-scoped tile of code zp operand to shared memory
SmemIteratorCodeScaleZp smem_iterator_code_zp_;
/// Shared memory write stage index
int smem_write_stage_idx_;
/// Shared memory read stage index
int smem_read_stage_idx_;
CUTLASS_DEVICE
ElementSuperScale* get_super_scale_smem_ptr() {
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ + kSuperScaleSmemOffset);
}
CUTLASS_DEVICE
ElementLocalScale* get_local_scale_smem_ptr() {
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ + kLocalScaleSmemOffset);
}
CUTLASS_DEVICE
ElementCodeScaleZp* get_code_scale_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeScaleSmemOffset);
}
CUTLASS_DEVICE
ElementCodeScaleZp* get_code_zp_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeZpSmemOffset);
}
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2ParamsAccessor(
///< prointer of shared memory
uint8_t* smem_pointer,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: smem_pointer_(smem_pointer),
smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx),
smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx),
smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {}
CUTLASS_DEVICE
SuperTensorRef super_scale_ref() {
return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
}
CUTLASS_DEVICE
LocalTensorRef local_scale_ref() {
return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
}
CUTLASS_DEVICE
CodeTensorRef code_scale_ref() {
return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}
CUTLASS_DEVICE
CodeTensorRef code_zp_ref() {
return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}
template <bool IsFirstStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) {
if constexpr (IsFirstStage) {
// Load channel-wise super_scale to shared memory, which only needs to be done once.
typename IteratorSuperScale::Fragment tb_frag_super_scale;
tb_frag_super_scale.clear();
quant_args.iterator_super_scale.load(tb_frag_super_scale);
this->smem_iterator_super_scale_.store(tb_frag_super_scale);
// Load channel-wise code_scale to shared memory, which only needs to be done once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_scale;
tb_frag_code_scale.clear();
quant_args.iterator_code_scale.load(tb_frag_code_scale);
this->smem_iterator_code_scale_.store(tb_frag_code_scale);
// Load channel-wise code_zp to shared memory, which only needs to be done once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_zp;
tb_frag_code_zp.clear();
quant_args.iterator_code_zp.load(tb_frag_code_zp);
this->smem_iterator_code_zp_.store(tb_frag_code_zp);
}
if ((stage % kStagesPerLocalScaleLoad) == 0) {
// Load group-wise local_scale to shared memory, which only needs to be done at each stage.
// Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages.
using AccessType = typename IteratorLocalScale::AccessType;
cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128)
? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
quant_args.iterator_local_scale.set_iteration_index(0);
this->smem_iterator_local_scale_.set_iteration_index(0);
// Async Copy for local_scale
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) {
AccessType *dst_ptr =
reinterpret_cast<AccessType *>(this->smem_iterator_local_scale_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) {
auto gmem_ptr = quant_args.iterator_local_scale.get();
int const kSrcBytes =
sizeof_bits<typename IteratorLocalScale::Element>::value *
IteratorLocalScale::ThreadMap::kElementsPerAccess /
IteratorLocalScale::kAccessesPerVector / 8;
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
}
++quant_args.iterator_local_scale;
}
++this->smem_iterator_local_scale_;
}
}
CUTLASS_DEVICE
void advance_smem_write_stage(Arguments &quant_args) {
if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
// Advance global iterators
quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset);
// Advance shared iterators
int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset);
}
// Increment shared memory write stage index
++smem_write_stage_idx_;
if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(pointer_offset);
smem_write_stage_idx_ = 0;
}
}
CUTLASS_DEVICE
int advance_smem_read_stage() {
int byte_offset = 0;
++smem_read_stage_idx_;
if (smem_read_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
byte_offset = kLocalScaleRows * kSmemColumns;
}
if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
smem_read_stage_idx_ = 0;
byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns;
}
return byte_offset;
}
CUTLASS_DEVICE
int clear_mask(Arguments &quant_args, bool cond) {
quant_args.iterator_local_scale.clear_mask(cond);
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,130 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/gemm_coord.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
int Stages, int NumThreads, WintQuantMethod Method>
struct TileDequanter {
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
using QuantArguments = typename WeightQuantTraits::Arguments;
using UnzipAndDequantFunctor =
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
static constexpr bool kUseSharedMemory = true;
static constexpr int kRows = Rows;
static constexpr int kColumns = Columns;
static constexpr int kStages = Stages;
MmaElementT *out_smem_ptr{nullptr};
char *pointer{nullptr};
int64_t ldm{0};
cutlass::MatrixCoord tb_offset;
cutlass::MatrixCoord extent;
ScaleElementT *super_scale_ptr{nullptr};
cutlass::MatrixCoord tb_offset_scale;
QuantArguments quant_args;
int64_t block_start_rows[kStages];
bool need_preload{true};
UnzipAndDequantFunctor unzip_functor;
CUTLASS_DEVICE
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
const cutlass::MatrixCoord &extent,
const cutlass::MatrixCoord &tb_offset,
ScaleElementT *super_scale_ptr,
const cutlass::MatrixCoord &tb_offset_scale,
const QuantArguments &quant_args)
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
CUTLASS_DEVICE
MmaElementT *GetOutPtr() { return out_smem_ptr; }
CUTLASS_DEVICE
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
tb_offset.row() += tile_offset.row() * kRows;
tb_offset.column() += tile_offset.column() * kColumns;
tb_offset_scale.column() += tile_offset.column() * kColumns;
}
CUTLASS_DEVICE
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
if (tb_offset.row() >= extent.row() ||
tb_offset.column() >= extent.column()) {
return;
}
block_start_rows[stage % kStages] = tb_offset.row();
using ZippedT = typename WeightQuantTraits::WeightType;
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
tb_offset.column();
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
(tb_offset.row() / 128) * ldm +
tb_offset_scale.column();
const float *code_scale_ptr =
quant_args.code_scale_ptr + tb_offset_scale.column();
const float *code_zp_ptr =
quant_args.code_zp_ptr + tb_offset_scale.column();
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
scale_ptr, &args, ldm, need_preload);
need_preload = false;
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
CUTLASS_DEVICE
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int64_t block_start_row = block_start_rows[stage % kStages];
if (block_start_row >= extent.row()) {
return;
}
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -41,9 +41,12 @@
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
namespace cutlass {
namespace gemm {
namespace warp {
namespace cutlass
{
namespace gemm
{
namespace warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -78,7 +81,7 @@ private:
// Shape for computing the FP16s
using ComputeInstructionShape = InstructionShape_;
// Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2.
// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
// Shape for loading the narrow data type from shared memory

View File

@@ -58,12 +58,15 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
namespace cutlass
{
namespace gemm
{
namespace warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
@@ -294,235 +297,6 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
/// Specialization for B of uint2b_t.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Instruction shape to override shared memory iterators with
typename SharedMemoryInstructionShape_,
/// Number of partitions along K dimension
int PartitionsK_,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
class MmaTensorOpComputeBWithF16<
Shape_,
ElementA_,
LayoutA_,
uint2b_t,
LayoutB_,
ElementC_,
LayoutC_,
Policy_,
SharedMemoryInstructionShape_,
PartitionsK_,
AccumulatorsInRowMajor>
{
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = uint2b_t;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename ArchMmaOperator::ArchTag;
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
&& ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
static_assert(platform::is_same<ElementA, half_t>::value
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename ArchMmaOperator::Shape;
/// Instruction shape to override shared memory iterators with
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
static_assert(
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
static_assert(
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
public:
/// Iterates over the A operand in memory
using IteratorA
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
kThreadCount, kPartitionsK>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB =
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements / kExpansionFactor>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Number of mma operations performed
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
ArchMmaOperator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaTensorOpComputeBWithF16() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C) const
{
using MmaOperandA = typename ArchMmaOperator::FragmentA;
using MmaOperandB = typename ArchMmaOperator::FragmentB;
using MmaOperandC = typename ArchMmaOperator::FragmentC;
D = C;
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
// Serpentine visitation order maximizing reuse of Rb
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n],
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n],
ptr_D[m_serpentine + n * MmaIterations::kRow]);
}
}
}
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Serpentine visitation order maximizing reuse of Ra
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine],
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine],
ptr_D[m + n_serpentine * MmaIterations::kRow]);
}
}
}
#else
assert(0);
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp

View File

@@ -1,442 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations
targeting Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass {
namespace gemm {
namespace warp {
namespace detail {
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<bfloat16_t> {
using Type = __nv_bfloat16;
using DualType = __nv_bfloat162;
};
template <>
struct DataTypeTraits<half_t> {
using Type = __half;
using DualType = __half2;
};
template <typename T, int N, typename Enable = void>
struct LocalScaleConverter {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<T, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
constexpr uint32_t kLocalScaleMask = 0xf;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
int32_t shifted_value = (static_cast<int32_t>(local_scale_frag[i]) >> shift_bit) & kLocalScaleMask;
scale_frag[i] = static_cast<T>(shifted_value) * super_scale_frag[i];
}
}
};
template <int N>
struct LocalScaleConverter<half_t, N, typename platform::enable_if<N % 4 == 0>::type> {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<half_t, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
constexpr uint32_t MASK = 0x000f000f;
// 2^10 = 1024
constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400;
// -2^10 = -1024
constexpr uint32_t FP16_BIAS = 0xE400E400;
// 1.0
constexpr uint32_t FP16_ONE = 0x3C003C00;
__half2* scale_ptr = reinterpret_cast<__half2 *>(&scale_frag);
__half2 const* super_scale_ptr = reinterpret_cast<__half2 const*>(&super_scale_frag);
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
int i4s = local_scale_ptr[i] >> shift_bit;
// unpack: 0, 1
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_FP16s_MAGIC_NUM);
// unpack: 2, 3
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_FP16s_MAGIC_NUM);
__half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0),
*reinterpret_cast<const __half2*>(&FP16_ONE),
*reinterpret_cast<const __half2*>(&FP16_BIAS));
__half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1),
*reinterpret_cast<const __half2*>(&FP16_ONE),
*reinterpret_cast<const __half2*>(&FP16_BIAS));
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
}
}
};
template <int N>
struct LocalScaleConverter<bfloat16_t, N, typename platform::enable_if<N % 4 == 0>::type> {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<bfloat16_t, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
constexpr uint32_t MASK = 0x000F000F;
constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
constexpr uint32_t BF16_BIAS = 0xC300C300;
constexpr uint32_t BF16_ONE = 0x3F803F80;
__nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162 *>(&scale_frag);
__nv_bfloat162 const* super_scale_ptr = reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag);
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
int i4s = local_scale_ptr[i] >> shift_bit;
// unpack: 0, 1
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_BF16s_MAGIC_NUM);
// unpack: 2, 3
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_BF16s_MAGIC_NUM);
nv_bfloat162 scale0 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack0),
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
nv_bfloat162 scale1 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack1),
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch::device_breakpoint();
#endif
}
};
} // namespace detail
////////////////////////////////////////////////////////////////////////////////
template <
/// Matrix multiply operator
typename MmaOperator_,
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Operand identity
Operand Operand,
/// Data type of Scale elements
typename ElementOperand_,
/// Layout of operand
typename Layout_,
/// Group size for quantization
int GroupSize_,
///
typename Enable = void>
class MmaTensorOpWin2xDequantizer {
//static_assert(false, "Not Supported!");
};
////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_,
/// Data type of Scale elements
typename ElementOperand_,
/// Group size for quantization
int GroupSize_>
class MmaTensorOpWin2xDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
ElementOperand_,
layout::RowMajor,
GroupSize_>
//typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
// && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{
public:
static_assert(platform::is_same<ElementOperand_, half_t>::value || platform::is_same<ElementOperand_, bfloat16_t>::value,
"T must be fp16 or bf16");
/// Mma Operator
using MmaOperator = MmaOperator_;
// The architecture specific mma ooperator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
/// Warp mma shape
using Shape = Shape_;
/// Type of mma operand
using ElementOperand = ElementOperand_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// Group size for quantization
static constexpr int kGroupSize = GroupSize_;
/// Type of input
using ElementB = typename MmaOperator::FragmentB::Element;
static_assert(platform::is_same<ElementB, uint2b_t>::value, "ElementB must be uint2b_t");
/// Type of the scales
using ElementLocalScale = uint4b_t;
using ElementSuperScale = ElementOperand;
using ElementCodeScaleZp = float;
// Fragment to hold scale data to apply to B before mma
// We need 1 fp16 per matrix iteration in the N dimension
static constexpr int kWarpIterationsAlongN = MmaOperator::MmaIterations::kColumn;
// use uint8_t to save 2 4-bits local scales
using FragmentLocalScale = Array<uint8_t, kWarpIterationsAlongN>;
using FragmentSuperScale = Array<ElementSuperScale, kWarpIterationsAlongN>;
using FragmentCodeScaleZp = Array<ElementCodeScaleZp, kWarpIterationsAlongN>;
/// Fragment to hold B data before Mma
using FragmentInput = Array<ElementB, MmaOperator::FragmentB::kElements>;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
static constexpr int kNumPacks = sizeof_bits<uint8_t>::value / sizeof_bits<ElementB>::value;
static constexpr int kUnpackFactor = MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks);
static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor;
/// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points.
using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementOperand, ElementB, MmaOperator::FragmentB::kElements / kUnpackFactor>;
using FragmentInputUnpack = typename Uint2Converter::result_type;
/// Fragment to hold internal scales before Mma
using FragmentScale = Array<ElementOperand, FragmentLocalScale::kElements>;
/// Fragment of dequantized B
using FragmentOutput = Array<ElementOperand, MmaOperator::FragmentB::kElements / kExpansionFactor>;
/// TensorRef type for loading element from a tensor
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, Layout>;
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, Layout>;
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, Layout>;
private:
//
// Data members
//
uint8_t* pointer_local_scale_;
ElementCodeScaleZp* pointer_code_scale_;
ElementCodeScaleZp* pointer_code_zp_;
ElementSuperScale* pointer_super_scale_;
//FragmentInputUnpack unpacked_frag_;
FragmentScale scale_frag_;
public:
CUTLASS_DEVICE
MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale,
LocalTensorRef smem_local_scale,
CodeTensorRef smem_code_scale,
CodeTensorRef smem_code_zp,
int warp_idx_n,
int lane_idx) {
int warp_offset = warp_idx_n * Shape::kN;
int quad = lane_idx / 4;
int thread_offset = warp_offset + quad;
pointer_super_scale_ = smem_super_scale.data() + thread_offset;
pointer_code_scale_ = smem_code_scale.data() + thread_offset;
pointer_code_zp_ = smem_code_zp.data() + thread_offset;
pointer_local_scale_ = reinterpret_cast<uint8_t *>(smem_local_scale.data()) + thread_offset;
}
/// Channel-wise params, need to load just once
CUTLASS_DEVICE
void load(FragmentCodeScaleZp& code_scale_frag,
FragmentCodeScaleZp& code_zp_frag,
FragmentSuperScale& super_scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN];
code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN];
}
}
/// Group-wise params, need to load multiple times
CUTLASS_DEVICE
void load(FragmentLocalScale& local_scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
}
}
CUTLASS_DEVICE
void dequantize(const FragmentLocalScale& local_scale_frag,
const FragmentCodeScaleZp& code_scale_frag,
const FragmentCodeScaleZp& code_zp_frag,
const FragmentSuperScale& super_scale_frag,
const FragmentInput& input_frag,
FragmentOutput& output_frag,
int tb_offset_k,
int warp_k_compute_offset) {
if constexpr (kUnpackInterval != 1) {
// unsupport now
arch::device_breakpoint();
}
typename Uint2Converter::source_type source_frag;
int in_offset = warp_k_compute_offset * kUnpackInterval;
uint8_t const* ptr_input = reinterpret_cast<uint8_t const*>(&input_frag);
uint8_t* ptr_source = reinterpret_cast<uint8_t *>(&source_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
ptr_source[mma_n_iter] = ptr_input[mma_n_iter * kUnpackFactor + in_offset];
}
FragmentInputUnpack unpacked_frag = Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag);
// dequantize local_scale
if (warp_k_compute_offset == 0) {
using LocalScaleConverter = detail::LocalScaleConverter<ElementOperand, FragmentLocalScale::kElements>;
// special for TileRows = 64
int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4;
LocalScaleConverter::Apply(local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift);
}
// unscale
// After applying LOP3 optimizations for performance, the B operand requires data rearrangement.
// reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]
const int kWarpIterationsAlongK = FragmentOutput::kElements / kWarpIterationsAlongN;
using Type = typename detail::DataTypeTraits<ElementOperand>::Type;
using DualType = typename detail::DataTypeTraits<ElementOperand>::DualType;
Type* output_ptr = reinterpret_cast<Type *>(&output_frag);
DualType const* unpacked_ptr = reinterpret_cast<DualType const*>(&unpacked_frag);
DualType const* scale_ptr = reinterpret_cast<DualType const*>(&scale_frag_);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; mma_n_iter += 2) {
int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK;
DualType scalex2 = scale_ptr[mma_n_iter / 2];
CUTLASS_PRAGMA_UNROLL
for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; ++mma_k_iter) {
DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter];
DualType scaled_value = __hmul2(unpacked_valuex2, scalex2);
output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = scaled_value.x;
output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = scaled_value.y;
}
}
}
/// Add an offset to pointer in units of elements.
/// Only group-wise params needs.
CUTLASS_DEVICE
void add_pointer_offset(int64_t const& offset) {
pointer_local_scale_ += offset;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@@ -39,25 +39,18 @@
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
#include "cutlass/trace.h"
namespace cutlass {
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
namespace cutlass
{
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter;
struct FastInterleavedAndBiasedNumericArrayConverter
{
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
@@ -447,329 +440,6 @@ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint2b_t, 16>
{
using result_type = Array<half_t, 16>;
using source_type = Array<uint2b_t, 16>;
using ScaleComputeT = float;
using code_type = Array<ScaleComputeT, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
ScaleComputeT new_code_zp = code_zp + 0.5f;
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert_impl(int32_t* decode_value)
{
result_type result;
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
static constexpr uint32_t MASK = 0x003F003F;
// 2^10 = 1024
static constexpr uint32_t EX = 0x64006400;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
h[3] = lop3<immLut>(q0, MASK, EX);
h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
h[7] = lop3<immLut>(q1, MASK, EX);
// 1024 + 32 = 1056
static constexpr uint32_t SUB = 0x64206420;
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
return convert(s, code_scale, code_zp);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint2b_t, 16>
{
using result_type = Array<bfloat16_t, 16>;
using source_type = Array<uint2b_t, 16>;
using ScaleComputeT = float;
using code_type = Array<ScaleComputeT, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
ScaleComputeT new_code_zp = code_zp + 0.5f;
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert_impl(int32_t* decode_value)
{
result_type result;
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
static constexpr uint32_t MASK = 0x003F003F;
// 2^7 = 128
static constexpr uint32_t EX = 0x43004300;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
h[3] = lop3<immLut>(q0, MASK, EX);
h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
h[7] = lop3<immLut>(q1, MASK, EX);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16))
// 128 + 32 = 160
static constexpr uint32_t SUB = 0x43204320;
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
#else
// 1.0
static constexpr uint32_t MUL = 0x3F803F80;
// -160
static constexpr uint32_t ADD = 0xC320C320;
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD));
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
return convert(s, code_scale, code_zp);
}
};
template <typename T, int N>
struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, N>
{
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
static constexpr int kVecWidth = 16;
static_assert(!(N % kVecWidth), "N must be multiple of 16.");
using result_type = Array<T, N>;
using source_type = Array<uint2b_t, N>;
using code_type = Array<float, N / kVecWidth>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, kVecWidth>;
using vec_source = Array<scalar_source_type, kVecWidth>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / kVecWidth; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]);
}
return result;
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, Array<float, N / 4> const& code_scale, Array<float, N / 4> const& code_zp)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>;
result_type result;
using vec_result = typename Converter::result_type;
using vec_source = typename Converter::source_type;
using vec_code = typename Converter::code_type;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
vec_code const* code_scale_ptr = reinterpret_cast<vec_code const*>(&code_scale);
vec_code const* code_zp_ptr = reinterpret_cast<vec_code const*>(&code_zp);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / kVecWidth; ++i)
{
result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp)
{
return convert(s, code_scale, code_zp);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@@ -125,13 +125,10 @@ struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
static constexpr int32_t kNumPackedValues = 4;
static constexpr int32_t kPackedSize = 16;
using LocalScaleType = uint4b_t;
using CodeScaleZpType = float;
struct Arguments {
uint8_t *local_scale_ptr; // quanted 4-bits
float *code_scale_ptr;
float *code_zp_ptr;
const uint8_t *local_scale_ptr; // quanted 4-bits
const float *code_scale_ptr;
const float *code_zp_ptr;
};
CUTLASS_DEVICE

View File

@@ -117,7 +117,7 @@ class LeftGELUAndMul {
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &lhs,
FragmentAccumulator const &rhs) const {
// Convert source to internal compute numeric type
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_to_compute;

View File

@@ -117,7 +117,7 @@ class LeftSiLUAndMul {
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &lhs,
FragmentAccumulator const &rhs) const {
// Convert source to internal compute numeric type
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_to_compute;

View File

@@ -92,7 +92,7 @@ class DualMmaBase {
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator0::Policy::MmaShape::kK);

View File

@@ -43,6 +43,7 @@
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -774,54 +775,17 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
template <WintQuantMethod QuantMethod, typename dummy>
struct KernelRunner<QuantMethod, true, dummy> {
using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments;
using QuantArguments = typename WeightQuantTraits::Arguments;
CUTLASS_DEVICE
static MmaQuantArguments prepare_quant_args(
Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset,
int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) {
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
weight_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);
typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale(
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
local_scale_ptr,
{(gemm_k + 127) / 128, gemm_n * 2},
thread_idx,
tb_offset_local_scale);
float* code_scale_ptr = params.code_scale + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
float* code_zp_ptr = params.code_zp + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_zp_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
MmaQuantArguments mma_quant_args(
iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset);
return mma_quant_args;
static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
QuantArguments quant_args;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
}
return quant_args;
}
CUTLASS_DEVICE
@@ -850,6 +814,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// LayoutB should be RowMajor
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;
//
// Problem visitor.
//
@@ -876,6 +843,12 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
0);
// begin address offset for weight_scale.
ElementScale* weight_scale_ptr =
params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Load element pointers. Exchange pointers and strides if working on
// the transpose
int64_t rows_to_jump = 0;
@@ -893,20 +866,42 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
// Compute initial location in logical coordinates
// the begin threadblock_offset of A, which holds the same row id with C
cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0};
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
// begin address offset for B for current problem_idx, totally num_experts problems
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix; // NOLINT
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B =
platform::is_same<layout::RowMajor, LayoutB>::value
? gemm_n
: gemm_k * kInterleave;
typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
// the begin threadblock_offset of B, which holds the same column id with C
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
MmaElementB* smem_unzip_B_ptr = nullptr;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
}
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
byte_ptr_B,
ldm_B,
extent_B,
tb_offset_B,
weight_scale_ptr,
tb_offset_scale,
quant_args);
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();
// Compute position within threadblock
int thread_idx = threadIdx.x;
@@ -919,21 +914,20 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B),
LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
ptr_B,
extent_B,
TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
thread_idx,
tb_offset_B);
MmaQuantArguments mma_quant_args = prepare_quant_args(
params, threadblock_offset, problem_idx, gemm_k, gemm_n, thread_idx);
TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
@@ -956,7 +950,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
accumulators,
iterator_A,
iterator_B,
mma_quant_args,
tile_dequanter_B,
accumulators);
//

View File

@@ -205,7 +205,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
threadblock_count,
epilogue_op,
reinterpret_cast<const ElementType*>(A),
reinterpret_cast<const CutlassMmaKernelType*>(B),
reinterpret_cast<const CutlassMmaWeightType*>(B),
reinterpret_cast<const ElementType*>(weight_scales),
reinterpret_cast<const ElementType*>(biases),
reinterpret_cast<ElementType*>(C),

View File

@@ -219,7 +219,7 @@ class EpilogueVisitorPerRowPerColNf4 {
iterator_C_.clear_mask();
}
// NOTE(wangbojun) Currently, this kernel don't hanve implantention for
// adding elementwise beta, we keep this here for future usage beta_ =
// adding elementwise beta, we keep this here for future useage beta_ =
// (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr :
// params.elementwise.beta); if (beta_ == ElementAccumulator()) {
// iterator_C_.clear_mask();

View File

@@ -176,7 +176,7 @@ struct Nf4DefaultIteratorsTensorOp<cutlass::bfloat16_t,
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
public:

View File

@@ -64,7 +64,7 @@ template <
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
/// Operation perfomed by GEMM
typename Operator,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.

View File

@@ -133,7 +133,7 @@ public:
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
static constexpr int kNumKIterationsPerWarpBLoad =

View File

@@ -509,7 +509,7 @@ public:
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
// TODO(wangbojun) lds_converter can be remove for int8 B input
// TOOD(wangbojun) lds_converter can be remove for int8 B input
typename TransformBAfterLDS::result_type converted_frag_B =
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);

View File

@@ -96,7 +96,7 @@ public:
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
static constexpr int kNumKIterationsPerWarpBLoad =

View File

@@ -646,7 +646,7 @@ public:
// );
// }
}
// TODO(wangbojun) lds_converter can be remove for int8 B input
// TOOD(wangbojun) lds_converter can be remove for int8 B input
// int4
// typename TransformBAfterLDS::result_type converted_frag_B =
// lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);

View File

@@ -223,11 +223,14 @@ public:
static Status can_implement(Arguments const &args)
{
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()");
// printf("--1\n");
// Initialize static kernel and device properties, if necessary.
Status result = init_device_props();
// printf("--1-2\n");
if (result != Status::kSuccess) {
return result;
}
// printf("--2\n");
dim3 grid = get_grid_shape(args);
// printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z);
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
@@ -235,6 +238,7 @@ public:
{
return Status::kErrorInvalidProblem;
}
// printf("--3\n");
return GemmKernel::can_implement(args);
}
@@ -281,50 +285,18 @@ public:
}
/// Returns the maximum number of active thread blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
static int maximum_active_blocks()
{
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()");
int smem_size = int(sizeof(typename GemmKernel_::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10)) {
result = cudaFuncSetAttribute(Kernel2<GemmKernel_>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error "
<< cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
Kernel2<GemmKernel_>,
GemmKernel_::kThreadCount,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));
// Initialize static device properties, if necessary
if (init_device_props() != Status::kSuccess) {
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
return sm_occupancy_;
}
@@ -369,7 +341,8 @@ public:
// Configure grid and block dimensions
dim3 block(GemmKernel::kThreadCount, 1, 1);
dim3 grid(params_.threadblock_count, 1, 1);
// dim3 grid = params_.get_grid_dims();
dim3 grid(216, 1, 1);
// Launch kernel
CUTLASS_TRACE_HOST(" "

View File

@@ -21,12 +21,12 @@ rm -rf up_gate_proj_7168_8192.log
rm -rf down_proj_8192_3584.log
num_experts=8
for tokens_per_expert in 1 2 4 8 16 20 24 28 32 36 48 64 96 128 160 192 224 256 384 512 768 1024 2048 3072 4096 8192
for tokens_per_expert in 12
do
wait
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 0 1 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 0 1 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
done
wait
echo "#### finish ####"

View File

@@ -996,6 +996,7 @@ int main(int argc, char *argv[]) {
CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
};
std::vector<SplitKStyle> all_split_k_style{SplitKStyle::NO_SPLIT_K};

View File

@@ -20,7 +20,7 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len,
int *mm_token_num_len,
int *seq_lens_this_time,
int *cu_seqlens_q,
float *hidden_states,
float *score_text,
float *output,
const int bsz,
const int hidden_size) {
@@ -32,11 +32,14 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len,
int max_seq_len_index_data = max_seq_len_index[0];
int mm_token_num_len_data = mm_token_num_len[0];
int true_bsz = cu_seqlens_q[bsz_index + 1] - 1;
if (bsz_index >= max_seq_len_index_data) {
true_bsz = true_bsz - mm_token_num_len_data;
}
if (max_seq_len_data == mm_token_num_len_data && bsz_index == max_seq_len_index_data) {
output[bsz_index * hidden_size + block_idx] = 0.0;
} else {
if (seq_lens_this_time[bsz_index] != 0) {
output[bsz_index * hidden_size + block_idx] = hidden_states[true_bsz * hidden_size + block_idx];
output[bsz_index * hidden_size + block_idx] = score_text[true_bsz * hidden_size + block_idx];
}
}
__syncthreads();
@@ -48,19 +51,19 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
const paddle::Tensor& mm_token_num_len,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& hidden_states) {
const paddle::Tensor& score_text) {
const int bsz = seq_lens_this_time.shape()[0];
const int hidden_size = hidden_states.shape()[1];
paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, hidden_states.place());
const int hidden_size = score_text.shape()[1];
paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, score_text.place());
extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, hidden_states.stream()>>>(
extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, score_text.stream()>>>(
const_cast<int*>(max_seq_len.data<int>()),
const_cast<int*>(max_seq_len_index.data<int>()),
const_cast<int*>(mm_token_num_len.data<int>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(cu_seqlens_q.data<int>()),
const_cast<float*>(hidden_states.data<float>()),
const_cast<float*>(score_text.data<float>()),
output.data<float>(),
bsz,
hidden_size
@@ -73,9 +76,9 @@ std::vector<std::vector<int64_t>> ExtractTextTokenOutputInferShape(const std::ve
const std::vector<int64_t>& mm_token_num_len_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& hidden_states_shape) {
const std::vector<int64_t>& score_text_shape) {
const int bsz = seq_lens_this_time_shape[0];
const int hidden_size = hidden_states_shape[1];
const int hidden_size = score_text_shape[1];
return {{bsz, hidden_size}};
}
@@ -84,8 +87,8 @@ std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(const paddle::Dat
const paddle::DataType& mm_token_num_len_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& hidden_states_dtype) {
return {hidden_states_dtype};
const paddle::DataType& score_text_dtype) {
return {score_text_dtype};
}
PD_BUILD_STATIC_OP(extract_text_token_output)
@@ -94,7 +97,7 @@ PD_BUILD_STATIC_OP(extract_text_token_output)
"mm_token_num_len",
"seq_lens_this_time",
"cu_seqlens_q",
"hidden_states"})
"score_text"})
.Outputs({"output"})
.SetKernelFn(PD_KERNEL(ExtractTextTokenOutput))
.SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape))

View File

@@ -1,163 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "kernel_traits.h"
#include "flash_mask_attn_kernel.hpp"
template <typename paddle_type>
struct cuteType;
template <>
struct cuteType<phi::dtype::float16> {
using type = cutlass::half_t;
};
template <>
struct cuteType<phi::dtype::bfloat16> {
using type = cutlass::bfloat16_t;
};
template <typename T>
std::vector<paddle::Tensor> DispatchFlashAttentionMask(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::optional<paddle::Tensor>& mask,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
const int batch_size = cu_seq_q.dims()[0];
paddle::Tensor out = paddle::empty(
{q_input.dims()[0], head_num * head_dim}, q_input.dtype(), q_input.place());
Flash_mask_params params;
memset(&params, 0, sizeof(Flash_mask_params));
params.q_ptr = const_cast<T*>(q_input.data<T>());
params.k_ptr = const_cast<T*>(k_input.data<T>());
params.v_ptr = const_cast<T*>(v_input.data<T>());
params.o_ptr = const_cast<T*>(out.data<T>());
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_seq_len_q = max_enc_len_this_time;
params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
params.batch_size = batch_size;
params.gqa_group_size = head_num / kv_head_num;
constexpr float kLog2e = 1.4426950408889634074;
params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
using cute_type = typename cuteType<T>::type;
if (mask) {
params.mask = const_cast<int*>(mask.get().data<int>());
flash_attn_headdim128<kBlockM, kBlockN, true, cute_type>(params, 0);
} else {
flash_attn_headdim128<kBlockM, kBlockN, false, cute_type>(params, 0);
}
return {out};
}
std::vector<paddle::Tensor> FlashAttentionMask(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::optional<paddle::Tensor> &mask,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
if (q_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
return std::move(
DispatchFlashAttentionMask<T>(
q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
mask,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time));
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
return std::move(
DispatchFlashAttentionMask<T>(
q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
mask,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time));
}
}
PD_BUILD_OP(flash_attention_mask)
.Inputs({
"q_input",
"k_input",
"v_input",
"cu_seq_q",
"cu_seq_k",
"seq_len_encoder",
paddle::Optional("mask")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_seq_len: int",
"max_enc_len_this_time: int",
"max_dec_len_this_time: int"})
.Outputs({
"out"})
.SetKernelFn(PD_KERNEL(FlashAttentionMask));

View File

@@ -1,231 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "kernel_traits.h"
#include "mainloop_attn.hpp"
#include "softmax.hpp"
using namespace cute;
template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
return make_layout(
make_shape(token_num, kHeadDim, head_num),
make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim));
}
template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws(
CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
CUTE_GRID_CONSTANT Flash_mask_params const data_params) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
constexpr int kHeadDim = Ktraits::kHeadDim;
constexpr bool NeedMask = Ktraits::NeedMask;
using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
__align__(16) __shared__ int mask[kBlockM];
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
if constexpr (NeedMask) {
const int *mask_this_batch = data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;
for (int i = threadIdx.x; i < kBlockM; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
mask[i] = mask_this_batch[i];
}
}
const int seq_len_q = data_params.seq_len_encoder[bidb];
const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];
if (m_block * kBlockM >= seq_len_q) {
return;
}
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1);
}
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
__syncthreads();
CollectiveMainloop collective_mainloop;
const int real_seq = seq_len_q - m_block * kBlockM;
const int n_block_max = NeedMask ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) : cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN);
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_write_k,
smem_pipe_write_v,
shared_storage,
n_block_max,
m_block,
bidh,
bidb,
data_params.cu_seq_q,
data_params.cu_seq_k,
seq_len_q,
seq_len_k);
}
} else { // Consumer
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
typename Ktraits::TiledMma1 tiled_mma1;
PipelineState smem_pipe_read_k, smem_pipe_read_v;
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
collective_mainloop.mma(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_read_k,
smem_pipe_read_v,
tOrO,
softmax,
mask,
n_block_max,
threadIdx.x - NumCopyThreads,
m_block,
seq_len_q,
seq_len_k,
shared_storage);
const int o_head_stride = data_params.head_num * kHeadDim;
const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;
collective_mainloop.store<NumMmaThreads>(
mainloop_params,
tOrO,
shared_storage,
tiled_mma1,
threadIdx.x - NumCopyThreads,
o_head_stride,
real_seq,
reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
}
}
template<typename Kernel_traits>
void run_flash_mask(Flash_mask_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_q, params.head_num),
static_cast<Element const*>(params.k_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
static_cast<Element const*>(params.v_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
params.scale_softmax_log2
});
int num_blocks_m = cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
void *kernel;
kernel = (void *)compute_attn_ws<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = num_blocks_m;
grid_dims.y = params.head_num;
grid_dims.z = params.batch_size;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
}
template <int kBlockM, int kBlockN, bool NeedMask, typename InputType>
void flash_attn_headdim128(Flash_mask_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
constexpr static int kNWarps = kBlockM / 16 + 4;
constexpr static int kStages = 2;
using Ktraits = Flash_mask_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, NeedMask, InputType>;
run_flash_mask<Ktraits>(params, stream);
}

View File

@@ -1,124 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
using namespace cute;
struct Flash_mask_params {
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void * __restrict__ o_ptr;
int * __restrict__ cu_seq_q;
int * __restrict__ cu_seq_k;
int * __restrict__ mask;
int * seq_len_encoder;
int head_num;
int kv_head_num;
int max_seq_len_q;
int max_seq_len_k;
int batch_size;
int gqa_group_size;
float scale_softmax_log2;
};
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
};
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool NeedMask_, typename elem_type=cutlass::half_t>
struct Flash_mask_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using index_t = int32_t;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
static constexpr int kStages = kStages_;
static constexpr int NeedMask = NeedMask_;
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma0 = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma1 = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutMNK{}));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyOValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using GmemTiledCopyO = decltype(make_tiled_copy(
TiledCopyOAtom{},
TiledCopyOThrLayout{},
TiledCopyOValLayout{}
));
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
};

View File

@@ -1,431 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "utils.hpp"
using namespace cute;
template <typename Ktraits>
struct CollectiveMainloopAttn {
using Element = typename Ktraits::Element;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr bool NeedMask = Ktraits::NeedMask;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutV = SmemLayoutK;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt =
decltype(cute::composition(SmemLayoutV{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
using SmemLayoutO = typename Ktraits::SmemLayoutO;
using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;
using TMA_Q = decltype(make_tma_copy(
GmemTiledCopyQ{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for Q
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
// Host side kernel arguments
struct Arguments {
Element const* ptr_Q;
LayoutT layout_Q;
Element const* ptr_K;
LayoutT layout_K;
Element const* ptr_V;
LayoutT layout_V;
float const softmax_scale_log2;
};
// Device side kernel params
struct Params {
LayoutT layout_Q;
LayoutT layout_K;
LayoutT layout_V;
cutlass::FastDivmod qhead_per_khead_divmod;
TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V;
float const softmax_scale_log2;
};
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
TMA_Q tma_load_Q = make_tma_copy(
GmemTiledCopyQ{},
mQ,
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{});
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
TMA_KV tma_load_K = make_tma_copy(
GmemTiledCopyKV{},
mK,
SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
TMA_KV tma_load_V = make_tma_copy(
GmemTiledCopyKV{},
mV,
SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {args.layout_Q, args.layout_K, args.layout_V,
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
tma_load_Q, tma_load_K, tma_load_V,
args.softmax_scale_log2};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(
const MTensor &m_tensor,
const Shape &tile_shape,
const int *cu_seq_len,
const int bidh,
const int bidb,
const int actual_seq_len) const {
auto g_offset = local_tile(
m_tensor(_, _, bidh),
cute::make_shape(1, get<1>(tile_shape)),
make_coord(cu_seq_len[bidb], _0{}));
auto g_sequence = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
g_offset.stride()
));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k,
PipelineState& smem_pipe_write_v,
SharedStorage &shared_storage,
const int n_block_max,
const int m_block,
const int bidh,
const int bidb,
const int *cu_seq_q,
const int *cu_seq_k,
const int seq_len_q,
const int seq_len_k) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
Tensor gQ = get_local_tile_tensor(
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
Tensor gK = get_local_tile_tensor(
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor gV = get_local_tile_tensor(
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK));
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV));
uint16_t mcast_mask_kv = 0;
int n_block = n_block_max - 1;
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
}
if (lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
}
if (lane_predicate) {
#pragma unroll 2
for (; n_block > 0; --n_block) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
template <typename SharedStorage, typename FrgTensorO, typename Softmax>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_read_k,
PipelineState& smem_pipe_read_v,
FrgTensorO& tOrO,
Softmax& softmax,
const int *mask,
const int n_block_max,
const int thread_idx,
const int m_block,
const int seq_len_q,
const int seq_len_k,
SharedStorage& shared_storage) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
typename Ktraits::TiledMma0 tiled_mma0;
typename Ktraits::TiledMma1 tiled_mma1;
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
Tensor tSrK = threadMma0.partition_fragment_B(sK);
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
int n_block = n_block_max - 1;
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
warpgroup_wait<0>();
pipeline_k.consumer_release(smem_pipe_read_k);
++smem_pipe_read_k;
int mask_start_idx;
int mask_row_id;
int col_base;
if constexpr (NeedMask) {
const int lane_id = thread_idx % 32;
mask_start_idx = mask[0] / kBlockN - 1;
mask_row_id = thread_idx / 32 * 16 + lane_id / 4;
col_base = thread_idx % 4 * 2;
app_mask(
tSrS,
mask,
mask_row_id,
col_base + n_block * kBlockN);
} else {
auto col_limit_causal = [&](int row, int n_block) {
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
};
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<1>(tScS(i))) >=
std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
tSrS(i) = -INFINITY;
}
}
}
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);
#pragma unroll 1
for (; n_block > 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
if constexpr (NeedMask) {
if (n_block >= mask_start_idx) {
app_mask(
tSrS,
mask,
mask_row_id,
col_base + n_block * kBlockN);
}
}
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V
++smem_pipe_read_k;
++smem_pipe_read_v;
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
}
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v);
++smem_pipe_read_v;
softmax.rescale_o(tOrO, scores_scale);
return;
}
template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO const& tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
const int o_head_stride,
const int real_seq,
T * out_ptr) {
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = convert_type<Element>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(o_head_stride, _1{}));
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
if (real_seq >= kBlockM) {
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
} else {
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
}
}
};

View File

@@ -1,206 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.hpp"
using namespace cute;
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
}
__forceinline__ __device__ __half2 half_exp(__half2 x) {
uint32_t tmp_out, tmp_in;
tmp_in = reinterpret_cast<uint32_t&>(x);
asm ("ex2.approx.f16x2 %0, %1;\n"
: "=r"(tmp_out)
: "r"(tmp_in));
__half2 out = reinterpret_cast<__half2&>(tmp_out);
return out;
}
// Apply the exp to all the elements.
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const float max_scaled = max(mi) * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
CUTLASS_DEVICE Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
cute::fill(scores_scale, 1.f);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*zero_init=*/false>(scores, row_max);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale(mi);
}
}
return scores_scale;
};
template<bool Is_first, typename Tensor0>
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
cute::fill(scores_scale, 1.f);
} else {
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
}
return scores_scale;
};
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT scores_scale;
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.0f / sum;
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
scores_scale(mi) = inv_sum;
}
return scores_scale;
};
template<typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
};

View File

@@ -1,453 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <fstream>
#include <iostream>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
using namespace cute;
template<typename T>
struct PackedHalf;
template<>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
};
template<>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
};
template<typename T>
__forceinline__ __device__ auto float_2_half2(const float x) {
if constexpr (std::is_same<T, cutlass::half_t>::value) {
return __float2half2_rn(x);
} else {
return __float2bfloat162_rn(x);
}
}
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
struct uint8 {
uint4 u;
uint4 v;
};
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
inline __device__ Vec() {}
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
inline __device__ void load_from(const void *base_ptr) {
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
}
inline __device__ void store_to(void *base_ptr) {
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
}
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
}
}
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
}
}
inline __device__ void set_zero() {
constexpr int size = sizeof(Vec_type) / sizeof(int);
#pragma unroll
for (int i = 0; i < size; ++i) {
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
}
}
};
template<typename T, int PackSize>
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
static_assert(PackSize % 2 == 0);
#pragma unroll
for (int i = 0; i < PackSize / 2; i++) {
const float cos_inv_freq = cos.data.elt[i];
const float sin_inv_freq = sin.data.elt[i];
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
}
}
template <typename Tensor>
__forceinline__ __device__ void app_mask(
Tensor &tSrS,
const int *mask,
const int &mask_row_id,
const int &col_base) {
const float mask_value = -1000000.0f;
for (int i = 0; i < size(tSrS); i+=8) {
const int col = i * 2 + col_base;
if (col >= mask[mask_row_id]) {
tSrS(i) = mask_value;
}
if (col + 1 >= mask[mask_row_id]) {
tSrS(i + 1) = mask_value;
}
if (col >= mask[mask_row_id + 8]) {
tSrS(i + 2) = mask_value;
}
if (col + 1 >= mask[mask_row_id + 8]) {
tSrS(i + 3) = mask_value;
}
if (col + 8 >= mask[mask_row_id]) {
tSrS(i + 4) = mask_value;
}
if (col + 9 >= mask[mask_row_id]) {
tSrS(i + 5) = mask_value;
}
if (col + 8 >= mask[mask_row_id + 8]) {
tSrS(i + 6) = mask_value;
}
if (col + 9 >= mask[mask_row_id + 8]) {
tSrS(i + 7) = mask_value;
}
}
}
template<typename T>
struct HalfMax;
template<>
struct HalfMax<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("max.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<>
struct HalfMax<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("max.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<typename T>
struct HalfMin;
template<>
struct HalfMin<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("min.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<>
struct HalfMin<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("min.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
__forceinline__ __device__ void copy(
TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &identity_MN,
const int max_MN = 0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
}
}
}
}
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template<typename T, typename ReductionOp, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
if (threadIdx.x == 0) { result_broadcast = result; }
__syncthreads();
return result_broadcast;
}
template<typename T, int block_size>
__inline__ __device__ T BlockScanSum(T val) {
typedef cub::BlockScan<int, block_size> BlockScanT;
__shared__ typename BlockScanT::TempStorage temp_storage;
int aggregate;
BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);
__syncthreads();
return val;
}
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
template<typename T>
struct MinOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
};
template <>
struct MinOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
};
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2,
typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
}
};
template<typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
ReductionOp op;
#pragma unroll
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}

Some files were not shown because too many files have changed in this diff Show More