Compare commits

..

16 Commits

Author SHA1 Message Date
chenjian
7b09611d6b [Bug fix] Fix batched token condition (#3565) 2025-08-23 11:55:53 +08:00
chenjian
606d9e9c2c [Feature] Support mixed deployment with adapter (#3517) 2025-08-21 18:19:01 +08:00
李泳桦
d18a637a17 [feat] add metrics for yiyan adapter (#3219)
* [feat] add metrics for yiyan adapter

* [fix] fix metrics num_requests_waiting and num_requests_running

* [fix] fix metrics gpu_cache_usage_perc

* [refactor] change where requests_number increases

* [chore] rename xxx_block_num as xxx_gpu_block_num, and update their values accordingly

* [chore] delete useless code
2025-08-21 16:58:10 +08:00
chenjian
6854506533 [Bug fix] Fix bug for d blocks not enough (#3479)
* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Fix bug for memory allocation

* Fix bug for D blocks not enough

* fix bug when d blocks not enough

* fix bug when d blocks not enough

* fix cache message recycle step

* fix cache message recycle step

* Fix step_idx recycle
2025-08-21 11:36:16 +08:00
chenjian
c487b62ee0 [Bug fix] Fix memory allocation (#3475)
* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Fix bug for memory allocation
2025-08-19 19:48:24 +08:00
chenjian
d2f6c3b998 [Bug fix] Fix bug for seq_len_encoder is 1 (#3467) 2025-08-19 15:21:32 +08:00
chenjian
aba94169dc [Feature] Support batched tokens for EP (#3415)
* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug
2025-08-18 11:43:36 +08:00
chenjian
3f86ae0007 fix cache messager bug when d restart (#3386) 2025-08-14 11:43:59 +08:00
chenjian
89177d881c [Bug fix] Fix zmq core bug (#3357)
* [Bug fix] Fix zmq core bug due to concurrently used by threads

* Fix zmq core bug due to concurrently used by threads
2025-08-13 20:24:39 +08:00
chenjian
7573802a88 [Feature] Support mtp ep in fd (#3340)
* [Optimize] Add metrics for analysing perf

* Fix bug in mtp
2025-08-11 21:49:44 +08:00
chenjian
110f33a530 [Bug fix] Test td cache messager (#3242)
* support disable cache task in decode node

* fix busg

* Update engine.py

* Update expert_service.py

* Update splitwise_connector.py

* Optimize log for debug

* Optimize log for debug

* fix bug

---------

Co-authored-by: ltd0924 <ltd0924@sina.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
2025-08-06 15:52:45 +08:00
chenjian
a4572a5e5d fix bug for pd step signal (#3230) 2025-08-06 10:41:52 +08:00
chenjian
a9d231c900 Fix bug for concurrently visit zmq (#3233) 2025-08-06 10:41:10 +08:00
ltd0924
b20ffe3697 [Feature] optimize expert parallel (#3196)
* optimize

* Update expert_service.py

* Update worker_process.py

* optimize
2025-08-05 17:34:24 +08:00
ltd0924
dcf9c2daff [Feature] Optimize prefix cache (#3208)
* [LLM] support ep

* Update worker_process.py

* Update expert_service.py

* Update worker_process.py

* format files

* optimize prefix cache

* optimize prefix cache

* optimize prefix cache

* pre commit format

* pre commit format

* pre commit format

* Update cache_messager.py
2025-08-05 17:13:11 +08:00
chenjian
9f9971844f [Feature] Support ep pd with external module (#3194)
* Support external module

* Support external module

* Support external module

* Support external module

* refactor code to make it more clear

* refactor code to make it more clear

* refactor code to make it more clear

* refactor code to make it more clear

* fix according to review

* fix according to review

* fix according to review

* fix according to review

* fix according to review

* fix according to review

* fix bug

* fix bug

* fix bug

* merge

---------

Co-authored-by: root <root@tjdm-inf-sci-k8s-hzz2-h12ni8-0202.tjdm.baidu.com>
2025-08-04 20:32:41 +08:00
701 changed files with 13280 additions and 74694 deletions

View File

@@ -1,186 +0,0 @@
name: Accuracy Test
description: "Run Accuracy Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
accuracy_tests:
runs-on: [self-hosted, GPU-h20-1Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Base Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
chmod +x ./llm-deploy-linux-amd64
./llm-deploy-linux-amd64 -python python3.10 \
-model_name ERNIE-4.5-0.3B-Paddle \
-model_path /MODELDATA \
--skip install
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
pushd tests/ce/deploy
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
popd
pushd tests/ce/accuracy_cases
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
export TEMPLATE=TOKEN_LOGPROB
export MODEL_SIZE=0.3B
TEST_EXIT_CODE=0
python gsm8k.py || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -1,229 +0,0 @@
name: Base Test
description: "Run Base Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
base_tests:
runs-on: [self-hosted, GPU-h20-1Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Base Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
chmod +x ./llm-deploy-linux-amd64
./llm-deploy-linux-amd64 -python python3.10 \
-model_name ERNIE-4.5-0.3B-Paddle \
-model_path /MODELDATA \
--skip install
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
pushd tests/ce/deploy
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
check_service() {
local timeout=${1:-90}
local url="http://localhost:${FLASK_PORT}/wait_for_infer?timeout=${timeout}"
local resp
resp=$(curl -s -X POST "$url")
if echo "$resp" | grep -q "服务启动超时"; then
exit 8
fi
}
check_service 90
popd
pushd tests/ce/server
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
export TEMPLATE=TOKEN_LOGPROB
TEST_EXIT_CODE=0
python -m pytest -sv test_base_chat.py test_compare_top_logprobs.py test_logprobs.py test_params_boundary.py test_seed_usage.py test_stream.py test_evil_cases.py test_completions.py test_return_token_ids.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--early-stop-config\": \"{\\\"enable_early_stop\\\":true, \\\"window_size\\\":6, \\\"threshold\\\":0.93}\"}"
check_service 90
python -m pytest -sv test_repetition_early_stop.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }"
check_service 90
python -m pytest -sv test_max_concurrency.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }"
check_service 90
python -m pytest -sv test_max_waiting_time.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_mtp.yaml\", \"--enable-logprob\": \"False\"}"
check_service 180
export TEMPLATE=TOKEN_NORMAL
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_sot.yaml\", \"--enable-logprob\": \"False\"}"
check_service 360
export TEMPLATE=TOKEN_NORMAL
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -22,22 +22,12 @@ on:
description: "Enable nightly build mode (e.g. add date suffix to version)"
required: false
type: string
default: "OFF"
default: "ON"
FD_VERSION:
description: "FastDeploy Package Version"
required: false
type: string
default: ""
PADDLEVERSION:
description: "Paddle Version Build Use"
required: false
type: string
default: ""
PADDLE_WHL_URL:
description: "Paddle Wheel Package URL"
required: false
type: string
default: ""
UPLOAD:
description: "Upload Package"
required: false
@@ -55,7 +45,6 @@ on:
jobs:
fd-build:
runs-on: [self-hosted, GPU-Build]
timeout-minutes: 240
outputs:
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
steps:
@@ -96,10 +85,6 @@ jobs:
compile_arch: ${{ inputs.COMPILE_ARCH }}
fd_version: ${{ inputs.FD_VERSION }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
BRANCH_REF: ${{ github.ref_name }}
PADDLEVERSION: ${{ inputs.PADDLEVERSION }}
PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }}
WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }}
run: |
set -x
runner_name="${{ runner.name }}"
@@ -124,9 +109,6 @@ jobs:
-e "COMPILE_ARCH=${compile_arch}" \
-e "FD_VERSION=${fd_version}" \
-e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \
-e "PADDLEVERSION=${PADDLEVERSION}" \
-e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
-e "BRANCH_REF=${BRANCH_REF}" \
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
if [[ -n "${FD_VERSION}" ]]; then
export FASTDEPLOY_VERSION=${FD_VERSION}
@@ -134,7 +116,6 @@ jobs:
fi
git config --global --add safe.directory /workspace/FastDeploy
chown -R $(whoami) /workspace/FastDeploy
cd FastDeploy
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
@@ -143,16 +124,10 @@ jobs:
echo "Date Only: $DATE_ONLY"
export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}"
fi
# 针对不同分支和tag使用不同的PaddlePaddle安装包
if [[ "${PADDLE_WHL_URL}" != "" ]];then
python -m pip install ${PADDLE_WHL_URL}
elif [[ "${PADDLEVERSION}" != "" ]];then
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
else
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
fi
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
pip config set install.trusted-host pip.baidu.com
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install --upgrade pip
python -m pip install -r requirements.txt

View File

@@ -68,7 +68,7 @@ jobs:
branch_name=${{ github.ref_name }}
target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -O bos_tools.py -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls

View File

@@ -62,24 +62,18 @@ jobs:
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
last_char="${runner_name: -1}"
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
if [[ "$last_char" =~ [0-7] ]]; then
DEVICES="$last_char"
else
DEVICES="0"
fi
FLASK_PORT=$((9160 + DEVICES * 100))
FD_API_PORT=$((9180 + DEVICES * 100))
FD_ENGINE_QUEUE_PORT=$((9150 + DEVICES * 100))
FD_METRICS_PORT=$((9170 + DEVICES * 100))
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
@@ -91,52 +85,28 @@ jobs:
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
PARENT_DIR=$(dirname "$WORKSPACE")
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
docker run --ipc=host --pid=host --net=host \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c '
# python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
pip config set install.trusted-host pip.baidu.com
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
@@ -154,10 +124,6 @@ jobs:
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
curl -s -o /dev/null -w "%{http_code}" -m 2 "http://0.0.0.0:${FD_API_PORT}/health"
curl -X POST "http://0.0.0.0:${FD_API_PORT}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{\"messages\": [{\"role\": \"user\", \"content\": \"1+1=?\"}], \"logprobs\": true}"
set +e
rm -rf ./baseline_output
cp -r baseline/ERNIE-4.5-0.3B-Paddle ./baseline_output

View File

@@ -21,16 +21,14 @@ on:
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
concurrency:
group: ${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
run_ce_cases:
runs-on: [self-hosted, PRE_CE_RUN_2Card]
timeout-minutes: 60
runs-on: [self-hosted, GPU-L20-4Card]
steps:
- name: Print current runner name
run: |
@@ -69,80 +67,38 @@ jobs:
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
last_char="${runner_name: -1}"
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
if [ "${last_char}" = "1" ]; then
gpu_id=2
DEVICES="2,3"
else
gpu_id=0
DEVICES="0,1"
fi
FD_API_PORT=$((9180 + gpu_id * 100))
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
FD_METRICS_PORT=$((9170 + gpu_id * 100))
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
PARENT_DIR=$(dirname "$WORKSPACE")
echo "PARENT_DIR:$PARENT_DIR"
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-v "/ssd4/GithubActions/gitconfig:/etc/gitconfig:ro" \
-v "/ssd4/GithubActions/ModelData:/ModelData:ro" \
-v "/ssd4/GithubActions/CacheDir:/root/.cache" \
-v "/ssd4/GithubActions/ConfigDir:/root/.config" \
-e "MODEL_PATH=/ModelData" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "fd_wheel_url=${fd_wheel_url}" \
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c '
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
# python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install ${fd_wheel_url}
bash scripts/run_pre_ce.sh
'

View File

@@ -1,170 +0,0 @@
name: Stable Test
description: "Run Stable Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
stable_tests:
runs-on: [self-hosted, GPU-h1z1-2Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Stable Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42038 + DEVICE_PORT * 100))
FD_INFERENCE_MSG_QUEUE_ID=$(( 42048 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
TEST_EXIT_CODE=0
pushd tests/ce/stable_cases
bash launch_model.sh /MODELDATA
bash run.sh || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -1,4 +1,4 @@
name: Coverage Check
name: Run FastDeploy Unit Tests and Coverage
description: "Run FastDeploy Unit Tests and Coverage"
on:
@@ -22,28 +22,10 @@ on:
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
secrets:
github-token:
required: true
jobs:
check_cov_skip:
uses: ./.github/workflows/check-bypass.yml
secrets:
github-token: ${{ secrets.github-token }}
with:
workflow-name: coverage
run_tests_with_coverage:
runs-on: [self-hosted, GPU-h1z1-2Cards]
timeout-minutes: 60
needs: check_cov_skip
if: needs.check_cov_skip.outputs.can-skip != 'true'
outputs:
diff_cov_file_url: ${{ steps.cov_upload.outputs.diff_cov_file_url }}
unittest_failed_url: ${{ steps.cov_upload.outputs.unittest_failed_url }}
@@ -85,128 +67,65 @@ jobs:
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
BASE_REF: ${{ github.event.pull_request.base.ref }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
IS_PR: ${{ github.event_name == 'pull_request' }}
run: |
if [[ "$IS_PR" == "true" ]]; then
echo "Running on PR"
else
echo "Not a PR"
fi
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
set -x
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
PARENT_DIR=$(dirname "$WORKSPACE")
echo "PARENT_DIR:$PARENT_DIR"
docker run --rm --net=host \
--cap-add=SYS_PTRACE --privileged --shm-size=64G \
-v $(pwd):/workspace -w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
-e "fd_wheel_url=${fd_wheel_url}" \
-e "BASE_REF=${BASE_REF}" \
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
# python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
pip config set install.trusted-host pip.baidu.com
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --net=host \
--name ${runner_name} \
--cap-add=SYS_PTRACE --shm-size=64G \
-v $(pwd):/workspace -w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
-e "MODEL_PATH=/ModelData" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e TZ="Asia/Shanghai" \
-e "fd_wheel_url=${fd_wheel_url}" \
-e "BASE_REF=${BASE_REF}" \
-e "IS_PR=${IS_PR}" \
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install coverage
python -m pip install diff-cover
python -m pip install pytest-cov
python -m pip install jsonschema aistudio_sdk==0.3.5
python -m pip install ${fd_wheel_url}
rm -rf fastdeploy
# coverage subprocess use
python -m pip install ${fd_wheel_url} --no-deps --target=/workspace/FastDeploy
export PYTHONPATH=/workspace/FastDeploy/
if [ -d "tests/plugins" ]; then
cd tests/plugins
python setup.py install
cd ../..
else
echo "Warning: tests/plugins directory not found, skipping setup.py install"
fi
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
TEST_EXIT_CODE=0
bash scripts/coverage_run.sh || TEST_EXIT_CODE=8
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env
coverage combine coveragedata/ || echo "No data to combine"
coverage report
coverage xml -o python_coverage_all.xml
COVERAGE_EXIT_CODE=0
if [[ "$IS_PR" == "true" ]]; then
python -m pip install coverage
python -m pip install diff-cover
python -m pip install ${fd_wheel_url}
if [ -d "test/plugins" ]; then
cd test/plugins
python setup.py install
cd ../..
else
echo "Warning: test/plugins directory not found, skipping setup.py install"
fi
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
TEST_EXIT_CODE=0
bash scripts/coverage_run.sh || TEST_EXIT_CODE=8
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env
coverage combine coveragedata/
coverage xml -o python_coverage_all.xml
COVERAGE_EXIT_CODE=0
diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml
else
echo "Not a PR, skipping diff-cover"
fi
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
'
if [ -f FastDeploy/exit_code.env ]; then
cat FastDeploy/exit_code.env >> $GITHUB_ENV
fi
'
if [ -f FastDeploy/exit_code.env ]; then
cat FastDeploy/exit_code.env >> $GITHUB_ENV
fi
- name: Upload unit resule and diff coverage to bos
id: cov_upload
shell: bash
@@ -215,7 +134,7 @@ jobs:
commit_id=${{ github.event.pull_request.head.sha }}
pr_num=${{ github.event.pull_request.number }}
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py -O bos_tools.py
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
diff_cov_file="diff_coverage.xml"
@@ -234,7 +153,7 @@ jobs:
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV
fi
unittest_result="failed_tests.log"
unittest_result="test/failed_tests.log"
if [ -s ${unittest_result} ];then
python ${push_file} ${unittest_result} ${target_path}/UnitTestResult
target_path_stripped="${target_path#paddle-github-action/}"
@@ -242,41 +161,32 @@ jobs:
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
fi
- name: Check Unit Test Success
- name: Determine Unit Succ and whether the coverage rate reaches 80%
shell: bash
run: |
cd FastDeploy
if [ "$TEST_EXIT_CODE" -eq 8 ]; then
filename=$(basename "$unittest_failed_url")
if [ -z "${unittest_failed_url}" ]; then
echo "No diff unit failed file URL provided."
else
rm -rf "${filename}"
wget -O ${filename} ${unittest_failed_url} || echo "Download unittest file failed, but continuing..."
wget ${unittest_failed_url} || echo "Download unittest file failed, but continuing..."
fi
echo "Unit tests failed (exit code 8)"
filename=$(basename "$unittest_failed_url")
if [ -f "${filename}" ];then
echo "Failed test cases:"
cat "${filename}"
fi
exit "$TEST_EXIT_CODE"
fi
echo "All tests passed"
- name: Verify Code Coverage Threshold (80%)
if: ${{ github.event_name == 'pull_request' }}
shell: bash
run: |
cd FastDeploy
if [ "$COVERAGE_EXIT_CODE" -eq 9 ]; then
echo "Coverage generation failed (exit code 9)"
filename=$(basename "$diff_cov_result_json_url")
if [ -z "${diff_cov_result_json_url}" ]; then
echo "No diff cov result file URL provided."
else
rm -rf "${filename}"
wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..."
wget ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..."
fi
filename=$(basename "$diff_cov_result_json_url")
if [ -f "${filename}" ];then
echo "Failed test cases:"
if command -v jq >/dev/null 2>&1; then
@@ -287,24 +197,19 @@ jobs:
fi
exit "$COVERAGE_EXIT_CODE"
fi
echo "coverage passed"
echo "All tests and coverage passed"
exit 0
diff_coverage_report:
needs: run_tests_with_coverage
if: always()
runs-on: ubuntu-latest
env:
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
steps:
- name: coverage diff file download
shell: bash
env:
diff_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url }}
run: |
wget ${fd_archive_url}
tar -xf FastDeploy.tar.gz
cd FastDeploy
if [ -z "${diff_cov_file_url}" ]; then
echo "No diff coverage file URL provided."
exit 0
@@ -314,9 +219,6 @@ jobs:
if: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url != null && needs.run_tests_with_coverage.outputs.diff_cov_file_url != '' }}
uses: codecov/codecov-action@v5
with:
files: ./FastDeploy/diff_coverage.xml
files: ./diff_coverage.xml
name: python diff coverage
verbose: true
disable_search: true
commit_parent: false
flags: diff

View File

@@ -6,9 +6,6 @@ on:
- develop
- 'release/*'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs:
Approval:
name: Approval

View File

@@ -1,248 +0,0 @@
name: CE Compile Job
on:
workflow_dispatch:
push:
branches:
- develop
- 'release/*'
permissions: read-all
concurrency:
group: ${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
ce_job_pre_check:
runs-on: ubuntu-latest
env:
COMPILE_BRANCH: ${{ vars.COMPILE_BRANCH }}
CE_COMPILE_SELECTION: ${{ vars.CE_COMPILE_SELECTION }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
outputs:
branch_match: ${{ steps.set_output.outputs.branch_match }}
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
sm8689_match: ${{ steps.set_output.outputs.sm8689_match }}
sm8090_match: ${{ steps.set_output.outputs.sm8090_match }}
steps:
- name: Set Version
id: set_output
env:
COMPILE_BRANCH: ${{ env.COMPILE_BRANCH }}
CE_COMPILE_SELECTION: ${{ env.CE_COMPILE_SELECTION }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ env.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
GITHUB_REF_NAME: ${{ github.ref_name }}
run: |
# 选择要触发编译任务的分支 done
# 选择指定分支要编译的任务 8090或者8689
# 指定分支编译要使用的Paddle的安装包,默认使用nightly最新的
IFS=',' read -ra BRANCHES <<< "$COMPILE_BRANCH"
MATCH=false
for b in "${BRANCHES[@]}"; do
if [[ "$b" == "${GITHUB_REF_NAME}" ]]; then
MATCH=true
break
fi
done
echo "branch_match=$MATCH" >> $GITHUB_OUTPUT
# 通过变量CE_COMPILE_SELECTION中的映射关系,决定分支是编译sm8090还是sm8689
for pair in $(echo "$CE_COMPILE_SELECTION" | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
compile_task_list=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "$GITHUB_REF_NAME" ]]; then
# 判断里面是否包含 sm8090 或 sm8689
if [[ "$compile_task_list" == *"sm8090"* ]]; then
echo "sm8090_match=true" >> $GITHUB_OUTPUT
fi
if [[ "$compile_task_list" == *"sm8689"* ]]; then
echo "sm8689_match=true" >> $GITHUB_OUTPUT
fi
break
fi
done
# 通过变量COMPILE_USE_PADDLE_WHL_URL_MAPPINGS中的映射关系,决定是否是安装指定版本的Paddle还是直接安装URL
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
FOUND_PADDLE_URL="$paddle_whl_url"
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
break
fi
done
print_ce_job_pre_check_outputs:
runs-on: ubuntu-latest
needs: ce_job_pre_check
steps:
- name: Print outputs as JSON
run: |
echo '${{ toJSON(needs.ce_job_pre_check.outputs) }}'
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
needs: ce_job_pre_check
if: ${{ needs.ce_job_pre_check.outputs.branch_match == 'true' }}
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request'
&& github.event.pull_request.base.ref
|| github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
build_sm8090:
name: BUILD_SM8090
needs: [clone, ce_job_pre_check]
if: ${{ needs.ce_job_pre_check.outputs.sm8090_match == 'true' }}
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "80,90"
WITH_NIGHTLY_BUILD: OFF
FD_VERSION: 0.0.0
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
build_sm8689:
name: BUILD_SM8689
needs: [clone, ce_job_pre_check]
if: ${{ needs.ce_job_pre_check.outputs.sm8689_match == 'true' }}
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "86,89"
WITH_NIGHTLY_BUILD: OFF
FD_VERSION: 0.0.0
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
ce_upload_sm8090:
environment: CodeSync
name: CE_UPLOAD
needs: build_sm8090
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
COMPILE_ARCH: "80,90"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
run: |
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
echo "commit wheel url is ${WHEEL_PATH}"
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
ce_upload_sm8689:
environment: CodeSync
name: CE_UPLOAD
needs: build_sm8689
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
COMPILE_ARCH: "86,89"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
run: |
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
echo "commit wheel url is ${WHEEL_PATH}"
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
echo "latest wheel url is ${WHEEL_PATH_LATEST}"

View File

@@ -1,51 +0,0 @@
on:
workflow_call:
inputs:
workflow-name:
required: true
type: string
secrets:
github-token:
required: true
outputs:
can-skip:
description: "Whether the workflow can be skipped."
value: ${{ jobs.check-bypass.outputs.can-skip }}
jobs:
check-bypass:
name: Check bypass
runs-on: ubuntu-latest
permissions:
contents: read
env:
CI_TEAM_MEMBERS: '["yuanlehome","YuanRisheng","Jiang-Jia-Jun","DDDivano","XieYunshen"]'
outputs:
can-skip: ${{ steps.check-bypass.outputs.can-skip }}
steps:
- name: Cleanup
run: |
rm -rf * .[^.]*
- id: check-bypass
name: Check Bypass
uses: PFCCLab/ci-bypass@v1
with:
github-token: ${{ secrets.github-token }}
non-pull-request-event-strategy: 'never-skipped'
type: 'composite'
composite-rule: |
{
"any": [
{
"type": "labeled",
"label": ["skip-ci: ${{ inputs.workflow-name }}", "skip-ci: all"],
"username": ${{ env.CI_TEAM_MEMBERS }}
},
{
"type": "commented",
"comment-pattern": [".*/skip-ci ${{ inputs.workflow-name }}.*", ".*/skip-ci all.*"],
"username": ${{ env.CI_TEAM_MEMBERS }}
}
]
}

20
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,20 @@
name: CI
on:
pull_request:
branches:
- develop
- 'release/*'
workflow_dispatch:
concurrency:
group: ${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Print current runner name
run: |
echo "The current CI tasks have been migrated to PreCe and will be deprecated soon."

View File

@@ -13,8 +13,7 @@ concurrency:
jobs:
CI_GCU:
runs-on:
group: GCU
runs-on: [self-hosted, GCU-S60-8Card]
steps:
- name: Print current runner name
run: |
@@ -29,9 +28,7 @@ jobs:
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace \
-v ${{ github.workspace }}/../../..:${{ github.workspace }}/../../.. \
-w /workspace \
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
@@ -42,7 +39,6 @@ jobs:
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
source ${{ github.workspace }}/../../../proxy
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
@@ -53,9 +49,6 @@ jobs:
git checkout ${{ github.sha }}
git log -n 3 --oneline
fi
echo "Copy models..."
sudo mkdir -p ci_models && sudo cp -r /work/deps/ERNIE-4.5-21B-A3B-Paddle ci_models
echo "Copy models done."
- name: Run CI unittest
env:
@@ -77,21 +70,19 @@ jobs:
echo "PARENT_DIR:$PARENT_DIR"
echo "Install drivers..."
cd /work/deps
sudo bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y
bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y
cd -
echo "Create docker..."
docker run --rm --network=host --ipc=host --privileged \
-v $(pwd):/workspace \
-v /home:/home \
-v /work:/work \
-w /workspace \
-e "MODEL_PATH=./ci_models" \
docker run --rm --network=host --ipc=host -it --privileged \
-v $(pwd):/workspace -w /workspace \
-v "/home:/home" \
-v "/work:/work" \
-e "MODEL_PATH=/work/models" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
${docker_image} /bin/bash -c "
${docker_image} /bin/bash -c "
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci_gcu.sh

View File

@@ -11,8 +11,7 @@ concurrency:
jobs:
CI_ILUVATAR:
runs-on:
group: IXUCA
runs-on: [self-hosted, IXUCA]
steps:
- name: Print current runner name
run: |

View File

@@ -24,7 +24,7 @@ jobs:
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
@@ -55,7 +55,7 @@ jobs:
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
@@ -77,7 +77,6 @@ jobs:
-e "MODEL_PATH=/ssd3/model" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \

View File

@@ -15,7 +15,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang mkdocs-static-i18n
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang
- name: Deploy to GitHub Pages
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -19,7 +19,7 @@ jobs:
needs: clone
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "89,90"
WITH_NIGHTLY_BUILD: "OFF"
@@ -39,59 +39,25 @@ jobs:
needs: [clone,build]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelCache"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

View File

@@ -1,321 +0,0 @@
name: Publish Job
on:
workflow_dispatch:
schedule:
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
push:
# branches:
# - develop
tags:
- '*'
permissions: read-all
concurrency:
group: ${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
publish_pre_check:
runs-on: ubuntu-latest
if: |
github.event.repository.fork == false &&
(
(github.event_name == 'schedule' && github.ref_name == 'develop') ||
(github.event_name == 'push' && github.ref_type == 'tag') ||
((github.event_name == 'workflow_dispatch') &&
(github.ref_name == 'develop' || github.ref_type == 'tag'))
)
env:
TAG_VERSION_MAPPINGS: ${{ vars.TAG_VERSION_MAPPINGS }}
FD_VERSION_DEV: ${{ vars.FD_VERSION_DEV }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
outputs:
compile_use_paddle_version: ${{ steps.set_output.outputs.compile_use_paddle_version }}
compile_continue: ${{ steps.set_output.outputs.compile_continue }}
fd_version: ${{ steps.set_output.outputs.fd_version }}
with_nightly_build: ${{ steps.set_output.outputs.with_nightly_build }}
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
steps:
- name: Get tag version
if: github.ref_type == 'tag'
run: |
TAG_NAME="${GITHUB_REF##*/}" # 提取 tag 名称,比如 v2.1.0
TAG_VERSION="${TAG_NAME#v}" # 去掉前缀 v
echo "FD_VERSION=$TAG_VERSION" >> $GITHUB_ENV
- name: Check FD version to Paddle version mapping
if: github.ref_type == 'tag'
env:
TARGET_FD: ${{ env.FD_VERSION }}
run: |
FOUND_PADDLE=""
# 遍历映射
for pair in $(echo $TAG_VERSION_MAPPINGS | tr ';' ' '); do
fd=$(echo "$pair" | cut -d',' -f1)
paddle=$(echo "$pair" | cut -d',' -f2)
if [[ "$fd" == "$TARGET_FD" ]]; then
FOUND_PADDLE="$paddle"
break
fi
done
if [[ -z "$FOUND_PADDLE" ]]; then
echo "No Paddle version found for FD $TARGET_FD"
else
echo "FD $TARGET_FD maps to Paddle $FOUND_PADDLE"
echo "PADDLE_VERSION=$FOUND_PADDLE" >> $GITHUB_ENV
fi
- name: Set Version
id: set_output
env:
PADDLE_VERSION: ${{ env.PADDLE_VERSION }}
FD_VERSION: ${{ env.FD_VERSION }}
run: |
if [[ "${{ github.ref_type }}" == "tag" ]]; then
if [[ -z "$PADDLE_VERSION" ]]; then
compile_continue=false
else
compile_use_paddle_version=$PADDLE_VERSION
compile_continue=true
fi
fd_version=$FD_VERSION
fi
if [[ "${{ github.ref_name }}" == "develop" ]];then
compile_continue=true
compile_use_paddle_version=""
fd_version=${FD_VERSION_DEV}
with_nightly_build=ON
fi
# Todo
# 通过变量COMPILE_USE_PADDLE_WHL_URL_MAPPINGS中的映射关系,决定是否是安装指定版本的Paddle还是直接安装URL
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
FOUND_PADDLE_URL="$paddle_whl_url"
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
compile_continue=true
break
fi
done
echo "compile_continue=${compile_continue}" >> $GITHUB_OUTPUT
echo "compile_use_paddle_version=${compile_use_paddle_version}" >> $GITHUB_OUTPUT
echo "fd_version=${fd_version}" >> $GITHUB_OUTPUT
echo "with_nightly_build=${with_nightly_build:-OFF}" >> $GITHUB_OUTPUT
print_publish_pre_check_outputs:
runs-on: ubuntu-latest
needs: publish_pre_check
steps:
- name: Print outputs as JSON
run: |
echo '${{ toJSON(needs.publish_pre_check.outputs) }}'
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
needs: publish_pre_check
if: ${{ needs.publish_pre_check.outputs.compile_continue == 'true' }}
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
if [[ "${{ github.ref_type }}" == "tag" ]]; then
commit_id=${{ github.sha }}
tag_name=${{ github.ref_name }}
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
else
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
build_sm8090:
name: BUILD_SM8090
needs: [clone, publish_pre_check]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "80,90"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
build_sm8689:
name: BUILD_SM8689
needs: [clone, publish_pre_check]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "86,89"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
paddle_pypi_upload_sm8090:
environment: PaddleSourceUpload
name: PADDLE_PYPI_UPLOAD_8090
needs: build_sm8090
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
COMPILE_ARCH: "80,90"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
if: github.ref_name == 'develop' || github.ref_type == 'tag'
run: |
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
if [[ "${{ github.ref_name }}" == "develop" ]];then
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
else
echo "Not develop or tag, do nothing"
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
paddle_pypi_upload_sm8689:
environment: PaddleSourceUpload
name: PADDLE_PYPI_UPLOAD_8689
needs: build_sm8689
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
COMPILE_ARCH: "86,89"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
if: github.ref_name == 'develop' || github.ref_type == 'tag'
run: |
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
if [[ "${{ github.ref_name }}" == "develop" ]];then
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
else
echo "Not develop or tag, do nothing"
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
unittest_coverage:
name: Run FastDeploy Unit Tests and Coverage
needs: [clone,build_sm8090]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build_sm8090]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build_sm8090]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

11
.gitignore vendored
View File

@@ -121,7 +121,7 @@ dmypy.json
FETCH_HEAD
#log
log/
log*/
checkpoints/
checkpoints_origin/
@@ -159,9 +159,6 @@ custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
#marlin_kernel
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu
#machete_kernel
custom_ops/gpu_ops/machete/generated
# buff
custom_ops/tmp*
@@ -170,9 +167,3 @@ build
.ccls-cache
third_party
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h

9
.gitmodules vendored
View File

@@ -1,9 +0,0 @@
[submodule "custom_ops/third_party/DeepGEMM"]
path = custom_ops/third_party/DeepGEMM
url = https://github.com/deepseek-ai/DeepGEMM.git
[submodule "custom_ops/third_party/cutlass"]
path = custom_ops/third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "custom_ops/third_party/nlohmann_json"]
path = custom_ops/third_party/nlohmann_json
url = https://github.com/nlohmann/json.git

View File

@@ -1,4 +1,3 @@
English | [简体中文](README_CN.md)
<p align="center">
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
</p>
@@ -23,12 +22,11 @@ English | [简体中文](README_CN.md)
</p>
--------------------------------------------------------------------------------
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
## News
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It now offers compatibility with models in the HuggingFace ecosystem, has further optimized performance, and newly adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!
**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
@@ -52,16 +50,14 @@ English | [简体中文](README_CN.md)
## Installation
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, **Hygon DCUs** and other hardware. For detailed installation instructions:
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. For detailed installation instructions:
- [NVIDIA GPU](./docs/get_started/installation/nvidia_gpu.md)
- [Kunlunxin XPU](./docs/get_started/installation/kunlunxin_xpu.md)
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md.md)
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates!
## Get Started
@@ -71,12 +67,20 @@ Learn how to use FastDeploy through our documentation:
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
- [Offline Inference Development](./docs/offline_inference.md)
- [Online Service Deployment](./docs/online_serving/README.md)
- [Best Practices](./docs/best_practices/README.md)
- [Full Supported Models List](./docs/supported_models.md)
- [Optimal Deployment](./docs/optimal_deployment/README.md)
## Supported Models
Learn how to download models, enable using the torch format, and more:
- [Full Supported Models List](./docs/supported_models.md)
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅(WINT4)| WIP |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|✅(WINT4)| WIP | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
## Advanced Usage

View File

@@ -1,89 +0,0 @@
[English](README.md) | 简体中文
<p align="center">
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
</p>
<p align="center">
<a href=""><img src="https://img.shields.io/badge/python-3.10-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux-pink.svg"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/graphs/contributors"><img src="https://img.shields.io/github/contributors/PaddlePaddle/FastDeploy?color=9ea"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
</p>
<p align="center">
<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/"><b> 安装指导 </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/quick_start"><b> 快速入门 </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/zh/supported_models/"><b> 支持模型列表 </b></a>
</p>
--------------------------------------------------------------------------------
# FastDeploy :基于飞桨的大语言模型与视觉语言模型推理部署工具包
## 最新活动
**[2025-09] 🔥 FastDeploy v2.2 全新发布**: HuggingFace生态模型兼容性能进一步优化更新增对[baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)支持!
**[2025-08] FastDeploy v2.1 发布**:全新的KV Cache调度策略更多模型支持PD分离和CUDA Graph昆仑、海光等更多硬件支持增强全方面优化服务和推理引擎的性能。
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
## 关于
**FastDeploy** 是基于飞桨PaddlePaddle的大语言模型LLM与视觉语言模型VLM推理部署工具包提供**开箱即用的生产级部署方案**,核心技术特性包括:
- 🚀 **负载均衡式PD分解**工业级解决方案支持上下文缓存与动态实例角色切换在保障SLO达标和吞吐量的同时优化资源利用率
- 🔄 **统一KV缓存传输**轻量级高性能传输库支持智能NVLink/RDMA选择
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
- 🧮 **全量化格式支持**W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
-**高级加速技术**推测解码、多令牌预测MTP及分块预填充
- 🖥️ **多硬件支持**NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
## 要求
- 操作系统: Linux
- Python: 3.10 ~ 3.12
## 安装
FastDeploy 支持在**英伟达NVIDIAGPU**、**昆仑芯KunlunxinXPU**、**天数IluvatarGPU**、**燧原EnflameGCU**、**海光HygonDCU** 以及其他硬件上进行推理部署。详细安装说明如下:
- [英伟达 GPU](./docs/zh/get_started/installation/nvidia_gpu.md)
- [昆仑芯 XPU](./docs/zh/get_started/installation/kunlunxin_xpu.md)
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md.md)
**注意:** 我们正在积极拓展硬件支持范围。目前包括昇腾AscendNPU 等其他硬件平台正在开发测试中。敬请关注更新!
## 入门指南
通过我们的文档了解如何使用 FastDeploy
- [10分钟快速部署](./docs/zh/get_started/quick_start.md)
- [ERNIE-4.5 部署](./docs/zh/get_started/ernie-4.5.md)
- [ERNIE-4.5-VL 部署](./docs/zh/get_started/ernie-4.5-vl.md)
- [离线推理](./docs/zh/offline_inference.md)
- [在线服务](./docs/zh/online_serving/README.md)
- [最佳实践](./docs/zh/best_practices/README.md)
## 支持模型列表
通过我们的文档了解如何下载模型如何支持torch格式等
- [模型支持列表](./docs/zh/supported_models.md)
## 进阶用法
- [量化](./docs/zh/quantization/README.md)
- [分离式部署](./docs/zh/features/disaggregated.md)
- [投机解码](./docs/zh/features/speculative_decoding.md)
- [前缀缓存](./docs/zh/features/prefix_caching.md)
- [分块预填充](./docs/zh/features/chunked_prefill.md)
## 致谢
FastDeploy 依据 [Apache-2.0 开源许可证](./LICENSE). 进行授权。在开发过程中,我们参考并借鉴了 [vLLM](https://github.com/vllm-project/vllm) 的部分代码,以保持接口兼容性,在此表示衷心感谢。

View File

@@ -1,6 +0,0 @@
num_gpu_blocks_override: 1024
max_model_len: 8192
max_num_seqs: 64
data_parallel_size: 8
tensor_parallel_size: 1
enable_expert_parallel: True

View File

@@ -1,8 +0,0 @@
top_p: 0.95
temperature: 0.6
metadata:
min_tokens: 1
max_tokens: 65535
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0

View File

@@ -1,10 +0,0 @@
reasoning-parser: ernie_x1
tool_call_parser: ernie_x1
tensor_parallel_size: 4
max_model_len: 65536
max_num_seqs: 128
enable_prefix_caching: True
enable_chunked_prefill: True
gpu_memory_utilization: 0.85
use_cudagraph: True
enable_custom_all_reduce: True

View File

@@ -34,6 +34,7 @@ EGG_DIR="fastdeploy.egg-info"
# custom_ops directory config
OPS_SRC_DIR="custom_ops"
OPS_TMP_DIR_BASE="tmp_base"
OPS_TMP_DIR="tmp"
# command line log config
@@ -70,20 +71,25 @@ function copy_ops(){
PY_VERSION="py${PY_MAIN_VERSION}.${PY_SUB_VERSION}"
SYSTEM_VERSION=`${python} -c "import platform; print(platform.system().lower())"`
PROCESSOR_VERSION=`${python} -c "import platform; print(platform.processor())"`
WHEEL_BASE_NAME="fastdeploy_base_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
WHEEL_NAME="fastdeploy_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
WHEEL_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
if [ "$is_rocm" = "True" ]; then
DEVICE_TYPE="rocm"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "ROCM ops have been copy to fastdeploy"
echo -e "BASE and ROCM ops have been copy to fastdeploy"
return
fi
mkdir -p ../fastdeploy/model_executor/ops/base
is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"`
if [ "$is_cuda" = "True" ]; then
DEVICE_TYPE="gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "CUDA ops have been copy to fastdeploy"
echo -e "BASE and CUDA ops have been copy to fastdeploy"
return
fi
@@ -106,8 +112,9 @@ function copy_ops(){
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
if [ "$if_corex" = "True" ]; then
DEVICE_TYPE="iluvatar-gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
echo -e "Iluvatar ops have been copy to fastdeploy"
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
return
fi
@@ -119,26 +126,20 @@ function copy_ops(){
return
fi
is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
if [ "$is_maca" = "True" ]; then
DEVICE_TYPE="metax_gpu"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "MACA ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../
cp -r ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu
echo -e "CPU ops have been copy to fastdeploy"
echo -e "BASE and CPU ops have been copy to fastdeploy"
return
}
function build_and_install_ops() {
cd $OPS_SRC_DIR
export no_proxy=bcebos.com,paddlepaddle.org.cn,${no_proxy}
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_base_ops..."
${python} setup_ops_base.py install --install-lib ${OPS_TMP_DIR_BASE}
find ${OPS_TMP_DIR_BASE} -type f -name "*.o" -exec rm -f {} \;
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..."
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
@@ -212,6 +213,7 @@ function cleanup() {
fi
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR_BASE
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR
}

View File

@@ -84,6 +84,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
seq_length,
bsz);
return {x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k};
@@ -96,7 +97,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -105,6 +106,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
@@ -113,6 +115,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,11 +19,10 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <typename T>
void RebuildPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cu_seqlens_q_data,
const int *cum_offsets_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
@@ -41,12 +40,11 @@ void RebuildPaddingCPUImpl(T *output_data,
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
continue;
}
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
const int ori_token_idx =
bi * max_input_length - cum_offsets_data[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
@@ -56,7 +54,7 @@ void RebuildPaddingCPUImpl(T *output_data,
template <typename T>
void RebuildAppendPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cu_seqlens_q_data,
const int *cum_offsets_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
@@ -71,32 +69,30 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
int bi = ori_token_id / max_input_length;
if (seq_len_this_time_data[bi] == 0 ||
(seq_lens_decoder_data[bi] == 0 &&
seq_lens_encoder_data[bi] == 0)) {
continue;
}
seq_lens_encoder_data[bi] == 0)) {
continue;
}
int seq_id = 0;
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id;
int bias_idx = i % dim_embed;
int src_offset = input_token_id * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
}
}
std::vector<paddle::Tensor> RebuildPaddingCPU(
const paddle::Tensor &tmp_out,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true);
auto seq_len_this_time_cpu =
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
auto seq_lens_decoder_cpu =
@@ -111,7 +107,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
int token_num = tmp_out_cpu.shape()[0];
int dim_embed = tmp_out_cpu.shape()[1];
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
int bsz = cum_offsets_cpu.shape()[0];
paddle::Tensor out;
if (output_padding_offset_cpu) {
@@ -132,7 +128,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
}
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
const int *cum_offsets_data = cum_offsets_cpu.data<int>();
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
@@ -145,7 +141,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32:
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -158,7 +154,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -171,7 +167,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -190,7 +186,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32:
RebuildPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -202,7 +198,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -211,10 +207,11 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
elem_nums);
break;
case paddle::DataType::BFLOAT16:
RebuildPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -233,7 +230,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &tmp_out_shape,
const std::vector<int64_t> &cu_seqlens_q_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &seq_len_this_time_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape,
@@ -242,14 +239,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
if (output_padding_offset_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cu_seqlens_q_shape[0] - 1;
int64_t bsz = cum_offsets_shape[0];
return {{bsz, dim_embed}};
}
}
std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &tmp_out_dtype,
const paddle::DataType &cu_seqlens_q_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &seq_len_this_time_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype,
@@ -259,7 +256,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
.Inputs({"tmp_out",
"cu_seqlens_q",
"cum_offsets",
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",

View File

@@ -38,7 +38,7 @@ class type2value<phi::dtype::float16> {
template <paddle::DataType D>
void AppendAttentionKernel(
std::vector<paddle::Tensor> AppendAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
@@ -60,7 +60,6 @@ void AppendAttentionKernel(
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -73,11 +72,7 @@ void AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
@@ -123,6 +118,27 @@ void AppendAttentionKernel(
} else {
qkv_out = qkv;
}
paddle::Tensor fmha_out;
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
D,
qkv.place());
}
auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
const paddle::Tensor& lambda_batch_ids,
@@ -207,10 +223,7 @@ void AppendAttentionKernel(
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
};
if (qkv_out_scales) {
@@ -326,10 +339,7 @@ void AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
} else {
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
@@ -353,10 +363,7 @@ void AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
}
}
@@ -385,6 +392,8 @@ void AppendAttentionKernel(
cudaStreamWaitEvent(main_stream, decoder_event);
}
}
return {fmha_out, qkv_out};
}
std::vector<paddle::Tensor> AppendAttention(
@@ -420,11 +429,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -459,60 +464,8 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];
// template dtype generation
phi::DataType dtype_id;
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
dtype_id = phi::DataType::BFLOAT16;
break;
} else if (compute_dtype == "fp16") {
dtype_id = phi::DataType::FLOAT16;
break;
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}
// fmha_out generation, rewrite from AppendAttentionKernel
paddle::Tensor fmha_out;
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
} else{
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
dtype_id,
qkv.place());
}
if (mask_offset) {
meta_data.mask_offset = mask_offset.get().data<int>();
}
auto dispatch_by_template = [&](auto temp_args) -> void {
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
@@ -534,7 +487,6 @@ std::vector<paddle::Tensor> AppendAttention(
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
@@ -547,11 +499,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
@@ -566,198 +514,35 @@ std::vector<paddle::Tensor> AppendAttention(
speculate_max_draft_token_num,
causal,
speculate_decoder);
};
};
phi::dtype::float16 fp16_dtype;
phi::dtype::bfloat16 bp16_dtype;
switch (dtype_id){
case phi::DataType::FLOAT16: {
dispatch_by_template(fp16_dtype);
return {fmha_out};
}
case phi::DataType::BFLOAT16: {
dispatch_by_template(bp16_dtype);
return {fmha_out};
}
default:
PD_THROW(
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype);
case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype);
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
return dispatch_by_template(bp16_dtype);
} else if (compute_dtype == "fp16") {
return dispatch_by_template(fp16_dtype);
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
break;
}
}
return {paddle::Tensor{}};
}
void AppendAttentionWithOutput(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
AppendAttnMetaData meta_data;
const auto& qkv_dims = qkv.dims();
const auto& key_cache_dims = key_cache.dims();
meta_data.token_nums = qkv_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
// TODO: trick method support c4, add attr head_dims in the future
if (cache_quant_type_str == "cache_int4_zp") {
meta_data.head_dims *= 2;
}
const int total_num_head =
qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];
if (mask_offset) {
meta_data.mask_offset = mask_offset.get().data<int>();
}
auto dispatch_by_template = [&](auto temp_args) -> void {
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
qkv_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
};
phi::dtype::float16 fp16_dtype;
phi::dtype::bfloat16 bp16_dtype;
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: {
dispatch_by_template(fp16_dtype);
break;
}
case paddle::DataType::BFLOAT16: {
dispatch_by_template(bp16_dtype);
break;
}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
dispatch_by_template(bp16_dtype);
break;
} else if (compute_dtype == "fp16") {
dispatch_by_template(fp16_dtype);
break;
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}
}
std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
@@ -791,11 +576,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -819,7 +600,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
}
const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
const int num_heads = total_num_head - 2 * kv_num_heads;
return {{token_num, num_heads * head_dim}};
return {{token_num, num_heads * head_dim}, qkv_shape};
}
std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -855,11 +636,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -878,148 +655,32 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
if (compute_dtype == "bf16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
return {paddle::DataType::INT8};
return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
return {paddle::DataType::FLOAT8_E4M3FN};
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
return {paddle::DataType::BFLOAT16};
return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
}
} else if (compute_dtype == "fp16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
return {paddle::DataType::INT8};
return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
return {paddle::DataType::FLOAT8_E4M3FN};
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
return {paddle::DataType::FLOAT16};
return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
}
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
}
}
std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& set_max_lengths_shape,
const std::vector<int64_t>& max_len_kv_shape,
const std::vector<int64_t>& fmha_out_shape,
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
return {fmha_out_shape};
}
std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::DataType& qkv_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& set_max_lengths_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::DataType& fmha_out_dtype,
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
return {fmha_out_dtype};
}
PD_BUILD_STATIC_OP(append_attention)
.Inputs({"qkv",
"key_cache",
@@ -1053,15 +714,11 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
paddle::Optional("kv_signal_data")})
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
.Attrs({"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
@@ -1075,71 +732,7 @@ PD_BUILD_STATIC_OP(append_attention)
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
})
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
PD_BUILD_STATIC_OP(append_attention_with_output)
.Inputs({"qkv",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"set_max_lengths",
"max_len_kv",
"fmha_out",
paddle::Optional("rotary_embs"),
paddle::Optional("attn_mask"),
paddle::Optional("qkv_bias"),
paddle::Optional("qkv_out_scales"),
paddle::Optional("cache_k_quant_scales"),
paddle::Optional("cache_v_quant_scales"),
paddle::Optional("cache_k_dequant_scales"),
paddle::Optional("cache_v_dequant_scales"),
paddle::Optional("cache_k_zp"),
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"fmha_out", "fmha_out_out"},
{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
})
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));

View File

@@ -43,7 +43,6 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -142,7 +141,6 @@ __global__ void multi_query_append_attention_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -181,7 +179,7 @@ __global__ void multi_query_append_attention_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
@@ -247,16 +245,12 @@ __global__ void multi_query_append_attention_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -412,8 +406,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -427,8 +419,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -511,7 +502,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -549,9 +540,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -619,15 +611,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -893,7 +882,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -951,7 +939,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1074,18 +1061,12 @@ void MultiQueryAppendAttention(
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_dec_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_seq_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
false,
@@ -1123,9 +1104,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1138,8 +1116,7 @@ void MultiQueryAppendAttention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1184,8 +1161,8 @@ void MultiQueryAppendAttention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1195,9 +1172,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1210,8 +1184,7 @@ void MultiQueryAppendAttention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
@@ -1235,8 +1208,8 @@ void MultiQueryAppendAttention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1253,14 +1226,14 @@ void MultiQueryAppendAttention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1271,8 +1244,8 @@ void MultiQueryAppendAttention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,

View File

@@ -48,7 +48,6 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -173,7 +172,6 @@ __global__ void multi_query_append_attention_c4_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -250,7 +248,7 @@ __global__ void multi_query_append_attention_c4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -335,15 +333,12 @@ __global__ void multi_query_append_attention_c4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -510,8 +505,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -525,8 +518,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -635,7 +627,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -711,9 +703,10 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -795,15 +788,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -1098,7 +1088,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1162,7 +1151,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1297,18 +1285,10 @@ void MultiQueryAppendC4Attention(
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1354,9 +1334,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1369,8 +1346,7 @@ void MultiQueryAppendC4Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1416,15 +1392,15 @@ void MultiQueryAppendC4Attention(
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_k_zp.get().data<T>()))
: nullptr,
const_cast<T *>(cache_k_zp.get().data<T>()))
: nullptr,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_v_zp.get().data<T>()))
: nullptr,
const_cast<T *>(cache_v_zp.get().data<T>()))
: nullptr,
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1434,9 +1410,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1449,8 +1422,7 @@ void MultiQueryAppendC4Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1473,8 +1445,8 @@ void MultiQueryAppendC4Attention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1491,14 +1463,14 @@ void MultiQueryAppendC4Attention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1509,8 +1481,8 @@ void MultiQueryAppendC4Attention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,

View File

@@ -48,7 +48,6 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -180,7 +179,6 @@ __global__ void multi_query_append_attention_c8_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -218,7 +216,7 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -302,15 +300,12 @@ __global__ void multi_query_append_attention_c8_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -479,8 +474,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -494,8 +487,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -609,7 +601,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -650,7 +642,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -736,16 +728,12 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -1066,7 +1054,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1124,7 +1111,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1268,17 +1254,10 @@ void MultiQueryAppendC8Attention(
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1339,9 +1318,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1354,8 +1330,7 @@ void MultiQueryAppendC8Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1402,8 +1377,8 @@ void MultiQueryAppendC8Attention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1413,9 +1388,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1428,8 +1400,7 @@ void MultiQueryAppendC8Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1447,8 +1418,8 @@ void MultiQueryAppendC8Attention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1465,14 +1436,14 @@ void MultiQueryAppendC8Attention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1483,8 +1454,8 @@ void MultiQueryAppendC8Attention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,

View File

@@ -905,15 +905,12 @@ template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
bool IS_SYSTEM = false>
__device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_idx_base,
__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
const uint32_t kv_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
float (*s_frag)[num_frags_z][8]) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -927,21 +924,10 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
group_size,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
}
const bool out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
@@ -949,7 +935,6 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +

View File

@@ -18,142 +18,6 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadKVT cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec, k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int half_head_size = head_size / 2;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
const uint32_t ori_idx =
start_token_idx * hidden_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (hi < num_heads + kv_num_heads) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
} else {
out_vec[2 * i] = src_vec[2 * i];
out_vec[2 * i + 1] = src_vec[2 * i + 1];
}
}
if (hi < (num_heads + kv_num_heads)) { // q k
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) { // q
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else { // k
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(out_vec, &qkv_out[ori_idx]);
} else {
// quant + write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
if (hi < num_heads + kv_num_heads) {
Store<T, VecSize>(out_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(out_vec, &value_cache[tgt_idx]);
}
}
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -335,9 +199,8 @@ __global__ void append_decode_cache_T_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -381,142 +244,6 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int rotary_dim,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec, right_vec;
LoadBiasT left_bias_vec, right_bias_vec;
LoadKVT left_cache_vec, right_cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int half_head_size = head_size / 2;
const int half_rotary_dim = rotary_dim / 2;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int64_t half_hidden_size = hidden_size / 2;
// const int64_t offset = 2 * hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / half_hidden_size;
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
if (hi < num_heads && h_bias >= half_rotary_dim){
continue;
}
if (seq_lens_encoder[ori_bi] > 0) continue;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
uint32_t ori_idx_left =
start_token_idx * hidden_size + hi * head_size + h_bias;
uint32_t ori_idx_right = ori_idx_left + half_head_size;
if (hi < num_heads){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else if (hi < num_heads + kv_num_heads){
if (h_bias < half_rotary_dim){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else{
ori_idx_left = ori_idx_left + half_rotary_dim;
ori_idx_right = ori_idx_left + half_rotary_dim;
}
}
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
if (h_bias < half_rotary_dim){
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// rope
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
} else {
// write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
uint32_t tgt_idx_left =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
if (hi < num_heads + kv_num_heads) {
if (h_bias < half_rotary_dim) {
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}else{
tgt_idx_left = tgt_idx_left + half_rotary_dim;
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
} else {
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
}
}
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -539,8 +266,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
@@ -587,9 +313,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -657,8 +382,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
using LoadT = AlignedVector<int, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
@@ -715,9 +439,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -789,8 +512,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -833,9 +555,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -912,11 +633,10 @@ __global__ void append_decode_cache_int8_rope_kernel(
const T *cache_v_scale_cur = cache_v_scale + v_head_idx * HeadDim + head_bias;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
if constexpr (!is_scale_channel_wise) {
scale = __ldg(&cache_k_scale[kv_head_idx]);
}
@@ -1043,8 +763,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1094,10 +813,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -1190,11 +908,10 @@ __global__ void append_decode_cache_int8_rope_kernel(
const T *cache_v_scale_cur = cache_v_scales + v_head_idx * HeadDim + head_bias;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
if constexpr (!is_scale_channel_wise) {
scale = __ldg(&cache_k_scales[kv_head_idx]);
}
@@ -1344,8 +1061,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1393,9 +1109,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -1476,11 +1191,10 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T scale;
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -1650,8 +1364,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1711,10 +1424,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -1822,11 +1533,10 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T scale;
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -2045,8 +1755,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -2090,9 +1799,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -2166,11 +1874,10 @@ __global__ void append_decode_cache_int4_rope_kernel(
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx], &scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx + 8], &scale_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -2347,8 +2054,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -2397,9 +2103,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -2486,11 +2191,10 @@ __global__ void append_decode_cache_int4_rope_kernel(
&out_scale_vec2);
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx], &scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx + 8], &scale_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -2674,8 +2378,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -2722,9 +2425,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// dequant + add_bias + rope
@@ -2805,11 +2507,10 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
Load<T, HALF_K_VEC_SIZE>(&qkv_now[right_bias_idx], &right_src_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_now[right_bias_idx + 8], &right_src_vec2);
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx],
&left_scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx + 8],
@@ -3051,8 +2752,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d) {
const int kv_num_heads) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -3110,9 +2810,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
&right_out_scale_vec);
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// dequant + add_bias + rope
@@ -3221,11 +2920,10 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
&right_out_scale_vec2);
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx],
&left_scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx + 8],

View File

@@ -15,69 +15,6 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const uint32_t elem_nums =
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope(const QKV_TYPE* qkv,
T* key_cache,
@@ -97,7 +34,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int rotary_dim,
const int block_size,
const int bsz,
const cudaStream_t& stream,
@@ -135,32 +71,9 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
if (rotary_dim < dim_head){
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
rotary_dim,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}else{
append_decode_cache_T_neox_rope_kernel<T, PackSize>
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
@@ -178,9 +91,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -287,8 +198,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int8_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -311,8 +221,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -339,8 +248,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -363,8 +271,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
}
@@ -428,8 +335,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int4_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -454,8 +360,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -484,8 +389,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -510,8 +414,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
}
@@ -538,10 +441,7 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
typedef cascade_attn_type_traits<T> traits_;
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
typedef typename traits_::type DataType_;
@@ -558,28 +458,82 @@ void DecoderWriteCacheWithRoPEKernel(
const float* cos_emb =
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
const float* sin_emb;
int rotary_dim = dim_head;
if (rotary_embs) {
sin_emb =
use_neox_rotary_style
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < dim_head){
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
}
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
}
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_decode_cache_rope_qk_norm(
if (cache_quant_type_str == "none") {
append_decode_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
bool is_scale_channel_wise = false;
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
is_scale_channel_wise = true;
}
if (is_scale_channel_wise) {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
@@ -592,6 +546,12 @@ void DecoderWriteCacheWithRoPEKernel(
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
@@ -601,192 +561,84 @@ void DecoderWriteCacheWithRoPEKernel(
bsz,
stream,
use_neox_rotary_style,
rope_3d,
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
rope_3d);
}
} else if (cache_quant_type_str == "cache_fp8") {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
if (cache_quant_type_str == "none") {
append_decode_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
rotary_dim,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
bool is_scale_channel_wise = false;
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
is_scale_channel_wise = true;
}
if (is_scale_channel_wise) {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
}
} else if (cache_quant_type_str == "cache_fp8") {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
"cache_int4_zp]");
}
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
"cache_int4_zp]");
}
}
@@ -815,10 +667,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -845,10 +694,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -874,10 +720,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -903,7 +746,4 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -40,6 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight, const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -33,8 +33,7 @@ __global__ void VariableLengthRotaryKernel(
const int64_t elem_cnt,
const int num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<int, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadScaleT = AlignedVector<float, VecSize>;
@@ -65,7 +64,6 @@ __global__ void VariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias;
const int64_t base_idx = token_idx * 3 * hidden_size + bias_idx;
Load<int, VecSize>(&qkv[base_idx], &src_vec);
@@ -74,8 +72,8 @@ __global__ void VariableLengthRotaryKernel(
}
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
if (qkv_id < 2) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -117,8 +115,7 @@ __global__ void VariableLengthRotaryKernel(
const int64_t elem_cnt,
const int num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -145,12 +142,11 @@ __global__ void VariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t base_idx = token_idx * 3 * hidden_size +
qkv_id * hidden_size + hi * last_dim + h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -181,8 +177,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int64_t elem_cnt,
const int num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<int, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadScaleT = AlignedVector<float, VecSize>;
@@ -216,7 +211,6 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
const int bias_idx_left =
qkv_id * full_hidden_size + hi * last_dim + h_bias;
const int bias_idx_right = bias_idx_left + half_lastdim;
@@ -231,8 +225,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
Load<float, VecSize>(&qkv_out_scales[bias_idx_left], &left_out_scale_vec);
Load<float, VecSize>(&qkv_out_scales[bias_idx_right], &right_out_scale_vec);
if (qkv_id < 2) {
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -275,8 +269,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int64_t elem_cnt,
const int num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
@@ -304,7 +297,6 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
const int base_idx_left = token_idx * 3 * full_hidden_size +
qkv_id * full_hidden_size + hi * last_dim +
h_bias;
@@ -312,8 +304,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
const float input_left = static_cast<float>(left_vec[i]);
@@ -366,7 +358,7 @@ __global__ void GQAVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
const int ori_bi = batch_id_per_token[token_idx];;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
@@ -375,7 +367,6 @@ __global__ void GQAVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t bias_idx = hi * last_dim + h_bias;
const int64_t base_idx = token_idx * offset + bias_idx;
Load<int, VecSize>(&qkv[base_idx], &src_vec);
@@ -384,8 +375,8 @@ __global__ void GQAVariableLengthRotaryKernel(
}
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -414,97 +405,6 @@ __global__ void GQAVariableLengthRotaryKernel(
}
}
template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryQKNormKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps
) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
LoadT src_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec, k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
const int half_lastdim = last_dim / 2;
const int offset = (q_num_head + kv_num_head) * last_dim;
const int all_head_num = elem_cnt / last_dim;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
}
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / last_dim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < q_num_head) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
for (int i = 0; i < VecSize; i++) {
src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
}
}
template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryKernel(
const T *qkv,
@@ -614,7 +514,6 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t bias_idx = hi * last_dim + h_bias;
const int64_t base_idx = token_idx * offset + bias_idx;
Load<int, VecSize>(&qkv[base_idx], &src_vec);
@@ -622,8 +521,8 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
float input_left = static_cast<float>(src_vec[2 * i]);
@@ -700,15 +599,14 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t bias_idx = hi * last_dim + h_bias;
const int64_t base_idx = token_idx * offset + bias_idx;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
if (qkv_biases) {
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = qkv_biases ? static_cast<float>(src_vec[2 * i]+ bias_vec[2 * i]) : static_cast<float>(src_vec[2 * i]);
@@ -756,8 +654,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<int, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadScaleT = AlignedVector<float, VecSize>;
@@ -787,7 +684,6 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
const int bias_idx_left = hi * last_dim + h_bias;
const int bias_idx_right = bias_idx_left + half_lastdim;
const int base_idx_left =
@@ -802,8 +698,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
Load<float, VecSize>(&qkv_out_scales[bias_idx_left], &left_out_scale_vec);
Load<float, VecSize>(&qkv_out_scales[bias_idx_right], &right_out_scale_vec);
if (hi < (q_num_head + kv_num_head)) {
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -849,8 +745,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
@@ -874,7 +769,6 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
const int base_idx_left =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
@@ -882,76 +776,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
const float input_left = static_cast<float>(left_vec[i]);
const float input_right = static_cast<float>(right_vec[i]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
}
}
template <typename T, int VecSize = 1>
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales,
const T *qkv_biases,
T *qkv_out,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int head_dim,
const int rotary_dim,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
LoadT right_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int rotary_dim_half = rotary_dim / 2;
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
for (int64_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / rotary_dim_half;
const int h_bias = bias % rotary_dim_half;
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
const int base_idx_left =
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
h_bias;
const int base_idx_right = base_idx_left + rotary_dim_half;
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
const float input_left = static_cast<float>(left_vec[i]);
@@ -1686,8 +1512,7 @@ void rotary_qk_variable(
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
dim_head);
} else {
VariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
@@ -1702,8 +1527,7 @@ void rotary_qk_variable(
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
dim_head);
}
} else {
const float *cos_emb = rotary_emb;
@@ -1724,8 +1548,7 @@ void rotary_qk_variable(
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
dim_head);
} else {
NeoxVariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
@@ -1740,72 +1563,11 @@ void rotary_qk_variable(
elem_nums,
head_num,
seq_len,
dim_head,
rope_3d);
dim_head);
}
}
}
template <typename T, typename QKV_TYPE>
void gqa_rotary_qk_norm_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
const QKV_TYPE *qkv_input, // qkv
const float *qkv_out_scales, // [3, num_head, dim_head]
const T *qkv_bias,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
const int num_heads,
const int kv_num_heads,
const int seq_len,
const int input_output_len,
const int dim_head,
const cudaStream_t &stream,
bool use_neox_style = false,
bool rope_3d = false,
const float *q_norm_weight = nullptr,
const float *k_norm_weight = nullptr,
const float rms_norm_eps = 1e-6) {
int64_t elem_nums =
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
assert(dim_head == 128 && "dim_head must be 128");
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1);
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
GQAVariableLengthRotaryQKNormKernel<T, PackSize>
<<<grid_size, Block_Size, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
void gqa_rotary_qk_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
@@ -1823,7 +1585,6 @@ void gqa_rotary_qk_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const int rotary_dim,
const cudaStream_t &stream,
bool use_neox_style = false,
bool rope_3d = false) {
@@ -1901,41 +1662,9 @@ void gqa_rotary_qk_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
dim_head);
} else {
if (rotary_dim < dim_head){
PD_CHECK((rotary_dim / 2) % PackSize == 0);
elem_nums =
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
if (use_neox_style) {
elem_nums /= 2;
}
const int pack_num_new = elem_nums / PackSize;
GetNumBlocks<128>(pack_num_new, &grid_size);
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
rotary_emb + input_output_len * rotary_dim / 2,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rotary_dim,
rope_3d);
}else{
GQANeoxVariableLengthRotaryKernel<T, PackSize>
GQANeoxVariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
@@ -1951,9 +1680,7 @@ void gqa_rotary_qk_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
}
dim_head);
}
}
}

View File

@@ -46,32 +46,38 @@ void EncoderWriteCacheWithRopeKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
auto token_num = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;
bool is_scale_channel_wise = false;
int rotary_dim = head_dim;
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
is_scale_channel_wise = true;
}
if (rotary_embs){
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < head_dim){
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
}
}
}
if (q_norm_weight && k_norm_weight) {
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
gqa_rotary_qk_norm_variable(
if (num_heads == kv_num_heads) {
rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
} else {
if (!is_scale_channel_wise) {
gqa_rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
@@ -89,81 +95,31 @@ void EncoderWriteCacheWithRopeKernel(
head_dim,
stream,
use_neox_style,
rope_3d,
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
rope_3d);
} else {
PD_THROW(
"gqa_rotary_qk_norm_variable only support gqa mode. channel wise scale and neox style are not supported");
gqa_rotary_qk_quant_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
}
} else {
if (num_heads == kv_num_heads) {
rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
} else {
if (!is_scale_channel_wise) {
gqa_rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
head_dim,
rotary_dim,
stream,
use_neox_style,
rope_3d);
} else {
gqa_rotary_qk_quant_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
}
}
}
const uint32_t block_size = meta_data.block_size;
if (cache_quant_type_str == "none") {

View File

@@ -289,7 +289,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
}
if (max_just_dec_len_this_time > 0) {

View File

@@ -37,8 +37,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -63,7 +62,6 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
@@ -82,8 +80,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
Load<T, VecSize>(&qkv[base_idx], &src_vec);
// do rope
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -120,7 +118,6 @@ void gqa_rotary_qk_split_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const bool rope_3d,
const cudaStream_t &stream) {
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int PackSize = 16 / sizeof(T);
@@ -149,8 +146,7 @@ void gqa_rotary_qk_split_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
dim_head);
}
template <typename T,
@@ -894,8 +890,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int kv_token_num,
const int max_seq_len,
const std::string& cache_quant_type,
const bool rope_3d) {
const std::string& cache_quant_type) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -958,9 +953,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
rotary_embs.dims()[2],
head_dim,
rope_3d,
stream);
if (token_num < kv_token_num) {

View File

@@ -43,7 +43,4 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -27,7 +27,6 @@ struct AppendAttnMetaData {
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
const int *mask_offset = nullptr;
};
__forceinline__ __host__ __device__ int div_up(int a, int b) {
@@ -431,9 +430,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (group_size == 12) { \
constexpr size_t GROUP_SIZE = 12; \
__VA_ARGS__ \
} else if (group_size == 14) { \
constexpr size_t GROUP_SIZE = 14; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
@@ -478,9 +474,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
if (causal) { \
constexpr bool CAUSAL = true; \
__VA_ARGS__ \
} else { \
constexpr bool CAUSAL = false; \
__VA_ARGS__ \
}
#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \
@@ -566,37 +559,3 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
convert_int8(result, source);
}
}
constexpr int kWarpSize = 32;
template<typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
*m2 += b_m2;
}
template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
*m2 = thread_m2;
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
WelfordCombine1(b_m2, m2);
}
}
template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}

View File

@@ -77,54 +77,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &out_linear_shifts,
const paddle::optional<paddle::Tensor> &out_linear_smooths,
const paddle::optional<paddle::Tensor> &mask_offset,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
const int max_input_length, const float quant_max_bound,
const float quant_min_bound, const float out_linear_in_scale,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int max_partition_size, const int encoder_max_partition_size,
const int speculate_max_draft_token_num, const bool causal,
const bool speculate_decoder);
void AppendAttentionWithOutput(
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
const paddle::Tensor &encoder_tile_ids_per_batch,
const paddle::Tensor &encoder_num_blocks,
const paddle::Tensor &kv_batch_ids,
const paddle::Tensor &kv_tile_ids_per_batch,
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
paddle::Tensor &fmha_out,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::optional<paddle::Tensor> &qkv_bias,
const paddle::optional<paddle::Tensor> &qkv_out_scales,
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_k_zp,
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &out_linear_shifts,
const paddle::optional<paddle::Tensor> &out_linear_smooths,
const paddle::optional<paddle::Tensor> &mask_offset,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
const int max_input_length, const float quant_max_bound,
@@ -154,8 +107,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const int kv_token_num, const int max_seq_len,
const std::string &cache_quant_type,
const bool rope_3d);
const std::string &cache_quant_type);
std::vector<paddle::Tensor>
PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
@@ -172,29 +124,11 @@ paddle::Tensor FusedExpertMoeFunc(
const std::string &quant_method, const int moe_topk,
const bool norm_topk_prob, const bool group_moe);
std::vector<paddle::Tensor> MacheteMMKernel(
paddle::Tensor const& A, paddle::Tensor const& B,
paddle::optional<paddle::Tensor> const& maybe_group_scales,
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
paddle::optional<paddle::Tensor> const& maybe_token_scales,
std::string const& b_type_str,
std::string const& maybe_out_type_str,
int64_t const& maybe_group_size,
std::string const& maybe_schedule);
std::vector<paddle::Tensor> MachetePrepackBKernel(
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
std::string const& maybe_group_scales_type_str);
std::vector<std::string> MacheteSupportedSchedules(
std::string const& a_type_str, std::string const& b_type_str);
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode);
const bool group_moe, const bool topk_only_mode);
std::vector<paddle::Tensor>
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
@@ -254,8 +188,7 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums);
const std::string& quant_method, const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
@@ -593,7 +526,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(int64_t _fa);
@@ -676,7 +609,7 @@ void SpeculateVerify(
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
@@ -721,20 +654,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
const int max_draft_tokens);
void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens);
// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -751,10 +670,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
@@ -845,31 +762,15 @@ void SpeculateStepPaddle(
const int max_draft_tokens);
void MergePrefillDecodeOutput(
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token);
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
const paddle::Tensor &top_p,
const paddle::optional<paddle::Tensor> &top_k,
int64_t seed);
std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
const paddle::Tensor &top_k);
std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
const paddle::Tensor &min_p);
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank);
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token);
PYBIND11_MODULE(fastdeploy_ops, m) {
@@ -924,7 +825,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* append_attention
*/
m.def("append_attention", &AppendAttention, "append attention function");
m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
/**
* gqa_rope_write_cache.cu
* gqa_rope_write_cache
@@ -956,7 +856,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
py::arg("gating_output"), py::arg("gating_correction_bias"),
py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function");
py::arg("topk_only_mode"), "moe export dispatch function");
/**
* moe/fused_moe/ep_moe_prefill_func.cu
@@ -986,27 +886,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("recv_expert_count"), py::arg("block_size"),
"per token per block quant");
#ifdef ENABLE_MACHETE
/*machete/machete_mm.cu
* machete_mm
*/
m.def("machete_mm", &MacheteMMKernel, py::arg("A"), py::arg("B"), py::arg("maybe_group_scale"),
py::arg("maybe_group_zeros"), py::arg("maybe_channel_scales"), py::arg("maybe_token_scales"),
py::arg("b_type_str"), py::arg("maybe_out_type_str"), py::arg("maybe_group_size"),
py::arg("maybe_schedule"),
"machete mm function");
/*machete/machete_prepack_B.cu
* machete_prepack_B
*/
m.def("machete_prepack_B", &MachetePrepackBKernel, "machete prepacked B function");
/*machete/machete_supported_schedules.cu
* machete_supported_schedules
*/
m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function");
#endif
/**
* moe/fused_moe/moe_topk_select.cu
* moe_topk_select
@@ -1220,7 +1099,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
@@ -1230,8 +1109,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
@@ -1247,12 +1124,4 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function");
m.def("rejection_top_p_sampling", &TopPSamplingReject, "rejection_top_p_sampling function");
m.def("top_k_renorm_probs", &TopKRenorm, "top_k_renorm_probs function");
m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function");
m.def("save_output", &SaveOutMmsgStatic, "save_output function");
}

View File

@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto stream = inp.stream();
@@ -163,12 +163,3 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void free_shared_buffer(fptr_t buffer) {
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
PD_BUILD_STATIC_OP(all_reduce)
.Inputs({"inp",
"out"})
.Outputs({"new_out"})
.Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
.SetInplaceMap({{"out", "new_out"}})
.SetKernelFn(PD_KERNEL(all_reduce));

View File

@@ -6,8 +6,6 @@
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
#include "helper.h"
// clang-format on
/*

View File

@@ -1,163 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "kernel_traits.h"
#include "flash_mask_attn_kernel.hpp"
template <typename paddle_type>
struct cuteType;
template <>
struct cuteType<phi::dtype::float16> {
using type = cutlass::half_t;
};
template <>
struct cuteType<phi::dtype::bfloat16> {
using type = cutlass::bfloat16_t;
};
template <typename T>
std::vector<paddle::Tensor> DispatchFlashAttentionMask(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::optional<paddle::Tensor>& mask,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
const int batch_size = cu_seq_q.dims()[0];
paddle::Tensor out = paddle::empty(
{q_input.dims()[0], head_num * head_dim}, q_input.dtype(), q_input.place());
Flash_mask_params params;
memset(&params, 0, sizeof(Flash_mask_params));
params.q_ptr = const_cast<T*>(q_input.data<T>());
params.k_ptr = const_cast<T*>(k_input.data<T>());
params.v_ptr = const_cast<T*>(v_input.data<T>());
params.o_ptr = const_cast<T*>(out.data<T>());
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_seq_len_q = max_enc_len_this_time;
params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
params.batch_size = batch_size;
params.gqa_group_size = head_num / kv_head_num;
constexpr float kLog2e = 1.4426950408889634074;
params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
using cute_type = typename cuteType<T>::type;
if (mask) {
params.mask = const_cast<int*>(mask.get().data<int>());
flash_attn_headdim128<kBlockM, kBlockN, true, cute_type>(params, 0);
} else {
flash_attn_headdim128<kBlockM, kBlockN, false, cute_type>(params, 0);
}
return {out};
}
std::vector<paddle::Tensor> FlashAttentionMask(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::optional<paddle::Tensor> &mask,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
if (q_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
return std::move(
DispatchFlashAttentionMask<T>(
q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
mask,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time));
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
return std::move(
DispatchFlashAttentionMask<T>(
q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
mask,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time));
}
}
PD_BUILD_OP(flash_attention_mask)
.Inputs({
"q_input",
"k_input",
"v_input",
"cu_seq_q",
"cu_seq_k",
"seq_len_encoder",
paddle::Optional("mask")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_seq_len: int",
"max_enc_len_this_time: int",
"max_dec_len_this_time: int"})
.Outputs({
"out"})
.SetKernelFn(PD_KERNEL(FlashAttentionMask));

View File

@@ -1,231 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "kernel_traits.h"
#include "mainloop_attn.hpp"
#include "softmax.hpp"
using namespace cute;
template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
return make_layout(
make_shape(token_num, kHeadDim, head_num),
make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim));
}
template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws(
CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
CUTE_GRID_CONSTANT Flash_mask_params const data_params) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
constexpr int kHeadDim = Ktraits::kHeadDim;
constexpr bool NeedMask = Ktraits::NeedMask;
using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
__align__(16) __shared__ int mask[kBlockM];
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
if constexpr (NeedMask) {
const int *mask_this_batch = data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;
for (int i = threadIdx.x; i < kBlockM; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
mask[i] = mask_this_batch[i];
}
}
const int seq_len_q = data_params.seq_len_encoder[bidb];
const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];
if (m_block * kBlockM >= seq_len_q) {
return;
}
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1);
}
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
__syncthreads();
CollectiveMainloop collective_mainloop;
const int real_seq = seq_len_q - m_block * kBlockM;
const int n_block_max = NeedMask ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) : cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN);
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_write_k,
smem_pipe_write_v,
shared_storage,
n_block_max,
m_block,
bidh,
bidb,
data_params.cu_seq_q,
data_params.cu_seq_k,
seq_len_q,
seq_len_k);
}
} else { // Consumer
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
typename Ktraits::TiledMma1 tiled_mma1;
PipelineState smem_pipe_read_k, smem_pipe_read_v;
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
collective_mainloop.mma(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_read_k,
smem_pipe_read_v,
tOrO,
softmax,
mask,
n_block_max,
threadIdx.x - NumCopyThreads,
m_block,
seq_len_q,
seq_len_k,
shared_storage);
const int o_head_stride = data_params.head_num * kHeadDim;
const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;
collective_mainloop.store<NumMmaThreads>(
mainloop_params,
tOrO,
shared_storage,
tiled_mma1,
threadIdx.x - NumCopyThreads,
o_head_stride,
real_seq,
reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
}
}
template<typename Kernel_traits>
void run_flash_mask(Flash_mask_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_q, params.head_num),
static_cast<Element const*>(params.k_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
static_cast<Element const*>(params.v_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
params.scale_softmax_log2
});
int num_blocks_m = cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
void *kernel;
kernel = (void *)compute_attn_ws<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = num_blocks_m;
grid_dims.y = params.head_num;
grid_dims.z = params.batch_size;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
}
template <int kBlockM, int kBlockN, bool NeedMask, typename InputType>
void flash_attn_headdim128(Flash_mask_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
constexpr static int kNWarps = kBlockM / 16 + 4;
constexpr static int kStages = 2;
using Ktraits = Flash_mask_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, NeedMask, InputType>;
run_flash_mask<Ktraits>(params, stream);
}

View File

@@ -1,124 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
using namespace cute;
struct Flash_mask_params {
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void * __restrict__ o_ptr;
int * __restrict__ cu_seq_q;
int * __restrict__ cu_seq_k;
int * __restrict__ mask;
int * seq_len_encoder;
int head_num;
int kv_head_num;
int max_seq_len_q;
int max_seq_len_k;
int batch_size;
int gqa_group_size;
float scale_softmax_log2;
};
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
};
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool NeedMask_, typename elem_type=cutlass::half_t>
struct Flash_mask_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using index_t = int32_t;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
static constexpr int kStages = kStages_;
static constexpr int NeedMask = NeedMask_;
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma0 = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma1 = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutMNK{}));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyOValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using GmemTiledCopyO = decltype(make_tiled_copy(
TiledCopyOAtom{},
TiledCopyOThrLayout{},
TiledCopyOValLayout{}
));
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
};

View File

@@ -1,431 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "utils.hpp"
using namespace cute;
template <typename Ktraits>
struct CollectiveMainloopAttn {
using Element = typename Ktraits::Element;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr bool NeedMask = Ktraits::NeedMask;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutV = SmemLayoutK;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt =
decltype(cute::composition(SmemLayoutV{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
using SmemLayoutO = typename Ktraits::SmemLayoutO;
using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;
using TMA_Q = decltype(make_tma_copy(
GmemTiledCopyQ{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for Q
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
// Host side kernel arguments
struct Arguments {
Element const* ptr_Q;
LayoutT layout_Q;
Element const* ptr_K;
LayoutT layout_K;
Element const* ptr_V;
LayoutT layout_V;
float const softmax_scale_log2;
};
// Device side kernel params
struct Params {
LayoutT layout_Q;
LayoutT layout_K;
LayoutT layout_V;
cutlass::FastDivmod qhead_per_khead_divmod;
TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V;
float const softmax_scale_log2;
};
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
TMA_Q tma_load_Q = make_tma_copy(
GmemTiledCopyQ{},
mQ,
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{});
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
TMA_KV tma_load_K = make_tma_copy(
GmemTiledCopyKV{},
mK,
SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
TMA_KV tma_load_V = make_tma_copy(
GmemTiledCopyKV{},
mV,
SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {args.layout_Q, args.layout_K, args.layout_V,
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
tma_load_Q, tma_load_K, tma_load_V,
args.softmax_scale_log2};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(
const MTensor &m_tensor,
const Shape &tile_shape,
const int *cu_seq_len,
const int bidh,
const int bidb,
const int actual_seq_len) const {
auto g_offset = local_tile(
m_tensor(_, _, bidh),
cute::make_shape(1, get<1>(tile_shape)),
make_coord(cu_seq_len[bidb], _0{}));
auto g_sequence = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
g_offset.stride()
));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k,
PipelineState& smem_pipe_write_v,
SharedStorage &shared_storage,
const int n_block_max,
const int m_block,
const int bidh,
const int bidb,
const int *cu_seq_q,
const int *cu_seq_k,
const int seq_len_q,
const int seq_len_k) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
Tensor gQ = get_local_tile_tensor(
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
Tensor gK = get_local_tile_tensor(
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor gV = get_local_tile_tensor(
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK));
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV));
uint16_t mcast_mask_kv = 0;
int n_block = n_block_max - 1;
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
}
if (lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
}
if (lane_predicate) {
#pragma unroll 2
for (; n_block > 0; --n_block) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
template <typename SharedStorage, typename FrgTensorO, typename Softmax>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_read_k,
PipelineState& smem_pipe_read_v,
FrgTensorO& tOrO,
Softmax& softmax,
const int *mask,
const int n_block_max,
const int thread_idx,
const int m_block,
const int seq_len_q,
const int seq_len_k,
SharedStorage& shared_storage) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
typename Ktraits::TiledMma0 tiled_mma0;
typename Ktraits::TiledMma1 tiled_mma1;
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
Tensor tSrK = threadMma0.partition_fragment_B(sK);
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
int n_block = n_block_max - 1;
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
warpgroup_wait<0>();
pipeline_k.consumer_release(smem_pipe_read_k);
++smem_pipe_read_k;
int mask_start_idx;
int mask_row_id;
int col_base;
if constexpr (NeedMask) {
const int lane_id = thread_idx % 32;
mask_start_idx = mask[0] / kBlockN - 1;
mask_row_id = thread_idx / 32 * 16 + lane_id / 4;
col_base = thread_idx % 4 * 2;
app_mask(
tSrS,
mask,
mask_row_id,
col_base + n_block * kBlockN);
} else {
auto col_limit_causal = [&](int row, int n_block) {
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
};
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<1>(tScS(i))) >=
std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
tSrS(i) = -INFINITY;
}
}
}
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);
#pragma unroll 1
for (; n_block > 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
if constexpr (NeedMask) {
if (n_block >= mask_start_idx) {
app_mask(
tSrS,
mask,
mask_row_id,
col_base + n_block * kBlockN);
}
}
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V
++smem_pipe_read_k;
++smem_pipe_read_v;
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
}
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v);
++smem_pipe_read_v;
softmax.rescale_o(tOrO, scores_scale);
return;
}
template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO const& tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
const int o_head_stride,
const int real_seq,
T * out_ptr) {
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = convert_type<Element>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(o_head_stride, _1{}));
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
if (real_seq >= kBlockM) {
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
} else {
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
}
}
};

View File

@@ -1,206 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.hpp"
using namespace cute;
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
}
__forceinline__ __device__ __half2 half_exp(__half2 x) {
uint32_t tmp_out, tmp_in;
tmp_in = reinterpret_cast<uint32_t&>(x);
asm ("ex2.approx.f16x2 %0, %1;\n"
: "=r"(tmp_out)
: "r"(tmp_in));
__half2 out = reinterpret_cast<__half2&>(tmp_out);
return out;
}
// Apply the exp to all the elements.
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const float max_scaled = max(mi) * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
CUTLASS_DEVICE Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
cute::fill(scores_scale, 1.f);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*zero_init=*/false>(scores, row_max);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale(mi);
}
}
return scores_scale;
};
template<bool Is_first, typename Tensor0>
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
cute::fill(scores_scale, 1.f);
} else {
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
}
return scores_scale;
};
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT scores_scale;
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.0f / sum;
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
scores_scale(mi) = inv_sum;
}
return scores_scale;
};
template<typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
};

View File

@@ -1,453 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <fstream>
#include <iostream>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
using namespace cute;
template<typename T>
struct PackedHalf;
template<>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
};
template<>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
};
template<typename T>
__forceinline__ __device__ auto float_2_half2(const float x) {
if constexpr (std::is_same<T, cutlass::half_t>::value) {
return __float2half2_rn(x);
} else {
return __float2bfloat162_rn(x);
}
}
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
struct uint8 {
uint4 u;
uint4 v;
};
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
inline __device__ Vec() {}
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
inline __device__ void load_from(const void *base_ptr) {
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
}
inline __device__ void store_to(void *base_ptr) {
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
}
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
}
}
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
}
}
inline __device__ void set_zero() {
constexpr int size = sizeof(Vec_type) / sizeof(int);
#pragma unroll
for (int i = 0; i < size; ++i) {
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
}
}
};
template<typename T, int PackSize>
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
static_assert(PackSize % 2 == 0);
#pragma unroll
for (int i = 0; i < PackSize / 2; i++) {
const float cos_inv_freq = cos.data.elt[i];
const float sin_inv_freq = sin.data.elt[i];
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
}
}
template <typename Tensor>
__forceinline__ __device__ void app_mask(
Tensor &tSrS,
const int *mask,
const int &mask_row_id,
const int &col_base) {
const float mask_value = -1000000.0f;
for (int i = 0; i < size(tSrS); i+=8) {
const int col = i * 2 + col_base;
if (col >= mask[mask_row_id]) {
tSrS(i) = mask_value;
}
if (col + 1 >= mask[mask_row_id]) {
tSrS(i + 1) = mask_value;
}
if (col >= mask[mask_row_id + 8]) {
tSrS(i + 2) = mask_value;
}
if (col + 1 >= mask[mask_row_id + 8]) {
tSrS(i + 3) = mask_value;
}
if (col + 8 >= mask[mask_row_id]) {
tSrS(i + 4) = mask_value;
}
if (col + 9 >= mask[mask_row_id]) {
tSrS(i + 5) = mask_value;
}
if (col + 8 >= mask[mask_row_id + 8]) {
tSrS(i + 6) = mask_value;
}
if (col + 9 >= mask[mask_row_id + 8]) {
tSrS(i + 7) = mask_value;
}
}
}
template<typename T>
struct HalfMax;
template<>
struct HalfMax<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("max.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<>
struct HalfMax<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("max.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<typename T>
struct HalfMin;
template<>
struct HalfMin<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("min.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<>
struct HalfMin<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("min.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
__forceinline__ __device__ void copy(
TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &identity_MN,
const int max_MN = 0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
}
}
}
}
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template<typename T, typename ReductionOp, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
if (threadIdx.x == 0) { result_broadcast = result; }
__syncthreads();
return result_broadcast;
}
template<typename T, int block_size>
__inline__ __device__ T BlockScanSum(T val) {
typedef cub::BlockScan<int, block_size> BlockScanT;
__shared__ typename BlockScanT::TempStorage temp_storage;
int aggregate;
BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);
__syncthreads();
return val;
}
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
template<typename T>
struct MinOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
};
template <>
struct MinOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
};
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2,
typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
}
};
template<typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
ReductionOp op;
#pragma unroll
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}

View File

@@ -109,11 +109,11 @@ void GetOutputEp(const paddle::Tensor& x,
return;
}
void GetOutputEPStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
GetOutputEp(x, rank_id, wait_flag, 1);
}
void GetOutputEPDynamic(const paddle::Tensor& x,
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
@@ -125,11 +125,11 @@ PD_BUILD_STATIC_OP(get_output_ep)
.Attrs({"rank_id: int64_t", "wait_flag: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(GetOutputEPStatic));
.SetKernelFn(PD_KERNEL(GetOutputStatic));
PD_BUILD_STATIC_OP(get_output_ep_dynamic)
.Inputs({"x"})
.Attrs({"rank_id: int64_t", "wait_flag: bool", "msg_queue_id: int"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(GetOutputEPDynamic));
.SetKernelFn(PD_KERNEL(GetOutputDynamic));

View File

@@ -46,11 +46,7 @@ __global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
#ifdef PADDLE_WITH_HIP
batch_id_per_token[bi * max_seq_len - cum_offset + i] = cum_offset;
#else
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
#endif
}
if (ti == 0) {
cum_offsets_out[bi] = cum_offset;
@@ -105,6 +101,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
cum_offsets_out.data<int>(),
seq_length);
return {x_remove_padding,
cum_offsets_out,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num};
@@ -117,7 +114,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -126,6 +123,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
@@ -134,6 +132,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"batch_id_per_token",
"cu_seqlens_q",
"cu_seqlens_k"})

View File

@@ -193,12 +193,6 @@ public:
typedef uint8_t data_t;
};
template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
public:
typedef __nv_fp8_e4m3 DataType;
typedef paddle::float8_e4m3fn data_t;
};
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
@@ -515,7 +509,6 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
}
#ifndef PADDLE_WITH_HIP
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
int mode = 0) {
uint32_t flag;
@@ -548,7 +541,7 @@ __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
"l"(flag_addr));
}
}
#endif
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,

View File

@@ -1,574 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import math
import os
import shutil
import sys
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass, fields
from functools import reduce
from typing import Optional, Union
import jinja2
cur_dir = os.path.dirname(os.path.abspath(__file__))
p = os.path.abspath(os.path.join(cur_dir, "../../third_party/cutlass/python"))
sys.path.insert(0, p)
from cutlass_library import (
EpilogueScheduleTag,
EpilogueScheduleType,
TileSchedulerTag,
TileSchedulerType,
)
# yapf conflicts with isort for this block
# yapf: disable
from machete_cutlass_library_extension import (
DataType,
MACHETEDataType,
MACHETEDataTypeMACHETEScalarTypeTag,
MACHETEDataTypeNames,
MACHETEDataTypePaddleDataTypeTag,
MACHETEDataTypeSize,
MACHETEDataTypeTag,
MACHETEKernelScheduleTag,
MixedInputKernelScheduleType,
)
# yapf: enable
#
# Generator templating
#
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
{% for impl_config in impl_configs %}
{% set type_sig = gen_type_sig(impl_config.types) -%}
{% for s in impl_config.schedules %}
extern paddle::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
{%- endfor %}
paddle::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
[[maybe_unused]] auto M = args.A.shape()[0];
[[maybe_unused]] auto N = args.B.shape()[1];
[[maybe_unused]] auto K = args.A.shape()[1];
if (!args.maybe_schedule) {
{%- for cond, s in impl_config.heuristic %}
{%if cond is not none%}if ({{cond}})
{%- else %}else
{%- endif %}
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
}
{%- for s in impl_config.schedules %}
if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
{%- endfor %}
PADDLE_ENFORCE(false, "machete_gemm(..) is not implemented ");
}
{%- endfor %}
static inline std::optional<paddle::DataType> maybe_scalartype(
std::optional<paddle::Tensor> const& t) {
if (!t) {
return std::nullopt;
} else {
return t->dtype();
};
}
paddle::Tensor mm_dispatch(MMArgs args) {
auto out_type = args.maybe_out_type.value_or(args.A.dtype());
auto a_type = args.A.dtype();
auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set type_sig = gen_type_sig(t) -%}
if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
&& a_type == {{PaddleTypeTag[t.a]}}
&& out_type == {{PaddleTypeTag[t.out]}}
&& {%if t.b_group_scale != void -%}
maybe_g_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
{%- else %}!maybe_g_scales_type{%endif%}
&& {%if t.b_group_zeropoint != void -%}
maybe_g_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
{%- else %}!maybe_g_zeros_type{%endif%}
&& {%if t.b_channel_scale != void -%}
maybe_ch_scales_type == {{PaddleTypeTag[t.b_channel_scale]}}
{%- else %}!maybe_ch_scales_type{%endif%}
&& {%if t.a_token_scale != void -%}
maybe_tok_scales_type == {{PaddleTypeTag[t.a_token_scale]}}
{%- else %}!maybe_tok_scales_type{%endif%}
) {
return mm_dispatch_{{type_sig}}(args);
}
{%- endfor %}
PADDLE_ENFORCE(
false, "machete_mm(..) is not implemented "
"; implemented types are: \\n",
{%- for impl_config in impl_configs %}
{% set t = impl_config.types -%}
"\\t{{gen_type_option_name(t)}}\\n",
{%- endfor %}
"");
}
std::vector<std::string> supported_schedules_dispatch(
SupportedSchedulesArgs args) {
auto out_type = args.maybe_out_type.value_or(args.a_type);
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set schs = impl_config.schedules -%}
if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
&& args.a_type == {{PaddleTypeTag[t.a]}}
&& out_type == {{PaddleTypeTag[t.out]}}
&& {%if t.b_group_scale != void -%}
args.maybe_group_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
{%- else %}!args.maybe_group_scales_type{%endif%}
&& {%if t.b_group_zeropoint != void-%}
args.maybe_group_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
{%- else %}!args.maybe_group_zeros_type{%endif%}
) {
return {
{%- for s in impl_config.schedules %}
"{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
{%- endfor %}
};
}
{%- endfor %}
return {};
};
}; // namespace machete
"""
IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
{% for sch in unique_schedules(impl_configs) %}
{% set sch_sig = gen_sch_sig(sch) -%}
struct sch_{{sch_sig}} {
using TileShapeNM = Shape<{{
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
using ClusterShape = Shape<{{
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
// TODO: Reimplement
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};
{% endfor %}
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set schs = impl_config.schedules -%}
{% set type_sig = gen_type_sig(t) -%}
template<typename Sch>
using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.a]}}, // ElementA
{{DataTypeTag[t.b]}}, // ElementB
{{DataTypeTag[t.out]}}, // ElementD
{{DataTypeTag[t.accumulator]}}, // Accumulator
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;
{% for sch in schs %}
{% set sch_sig = gen_sch_sig(sch) -%}
paddle::Tensor
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
}
{%- endfor %}
{%- endfor %}
}; // namespace machete
"""
PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"
namespace machete {
paddle::Tensor prepack_B_dispatch(PrepackBArgs args) {
auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
{%- for t in types %}
{% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
if (args.a_type == {{PaddleTypeTag[t.a]}}
&& args.b_type.size_bits() == {{t.b_num_bits}}
&& convert_type == {{PaddleTypeTag[t.convert]}}) {
return prepack_impl<
PrepackedLayoutBTemplate<
{{DataTypeTag[t.a]}}, // ElementA
{{DataTypeTag[b_type]}}, // ElementB
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B);
}
{%- endfor %}
PADDLE_ENFORCE(false,
"prepack_B_dispatch(..) is not implemented");
}
}; // namespace machete
"""
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass(frozen=True)
class ScheduleConfig:
tile_shape_mn: tuple[int, int]
cluster_shape_mnk: tuple[int, int, int]
kernel_schedule: MixedInputKernelScheduleType
epilogue_schedule: EpilogueScheduleType
tile_scheduler: TileSchedulerType
@dataclass(frozen=True)
class TypeConfig:
a: DataType
b: Union[DataType, MACHETEDataType]
b_group_scale: DataType
b_group_zeropoint: DataType
b_channel_scale: DataType
a_token_scale: DataType
out: DataType
accumulator: DataType
@dataclass(frozen=True)
class PrepackTypeConfig:
a: DataType
b_num_bits: int
convert: DataType
accumulator: DataType
@dataclass
class ImplConfig:
types: TypeConfig
schedules: list[ScheduleConfig]
heuristic: list[tuple[Optional[str], ScheduleConfig]]
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
tile_shape = f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
cluster_shape = (
f"{schedule_config.cluster_shape_mnk[0]}"
+ f"x{schedule_config.cluster_shape_mnk[1]}"
+ f"x{schedule_config.cluster_shape_mnk[2]}"
)
kernel_schedule = MACHETEKernelScheduleTag[schedule_config.kernel_schedule].split("::")[-1]
epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split("::")[-1]
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
return f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + f"_{epilogue_schedule}_{tile_scheduler}"
# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
kernel_terse_names_replace = {
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
"TmaWarpSpecializedCooperative_": "TmaCoop_",
"StreamKScheduler": "streamK",
}
sch_sig = generate_sch_sig(schedule_config)
for orig, terse in kernel_terse_names_replace.items():
sch_sig = sch_sig.replace(orig, terse)
return sch_sig
# unique type_name
def generate_type_signature(kernel_types: TypeConfig):
return str("".join([MACHETEDataTypeNames[getattr(kernel_types, field.name)] for field in fields(TypeConfig)]))
def generate_type_option_name(kernel_types: TypeConfig):
return ", ".join(
[
f"{field.name.replace('b_', 'with_')+'_type'}=" + MACHETEDataTypeNames[getattr(kernel_types, field.name)]
for field in fields(TypeConfig)
]
)
def is_power_of_two(n):
return (n != 0) and (n & (n - 1) == 0)
def to_cute_constant(value: list[int]):
def _to_cute_constant(value: int):
if is_power_of_two(value):
return f"_{value}"
else:
return f"Int<{value}>"
if isinstance(value, Iterable):
return [_to_cute_constant(value) for value in value]
else:
return _to_cute_constant(value)
def unique_schedules(impl_configs: list[ImplConfig]):
return list(set(sch for impl_config in impl_configs for sch in impl_config.schedules))
def unsigned_type_with_bitwidth(num_bits):
return {
4: DataType.u4,
8: DataType.u8,
16: DataType.u16,
32: DataType.u32,
64: DataType.u64,
}[num_bits]
template_globals = {
"void": DataType.void,
"DataTypeTag": MACHETEDataTypeTag,
"MACHETEScalarTypeTag": MACHETEDataTypeMACHETEScalarTypeTag,
"PaddleTypeTag": MACHETEDataTypePaddleDataTypeTag,
"KernelScheduleTag": MACHETEKernelScheduleTag,
"EpilogueScheduleTag": EpilogueScheduleTag,
"TileSchedulerTag": TileSchedulerTag,
"to_cute_constant": to_cute_constant,
"gen_sch_sig": generate_terse_sch_sig,
"gen_type_sig": generate_type_signature,
"unique_schedules": unique_schedules,
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
"gen_type_option_name": generate_type_option_name,
}
def create_template(template_str):
template = jinja2.Template(template_str)
template.globals.update(template_globals)
return template
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
sources = []
sources.append(
(
"machete_mm_dispatch",
mm_dispatch_template.render(impl_configs=impl_configs),
)
)
prepack_types = []
for impl_config in impl_configs:
convert_type = (
impl_config.types.a
if impl_config.types.b_group_scale == DataType.void
else impl_config.types.b_group_scale
)
prepack_types.append(
PrepackTypeConfig(
a=impl_config.types.a,
b_num_bits=MACHETEDataTypeSize[impl_config.types.b],
convert=convert_type,
accumulator=impl_config.types.accumulator,
)
)
def prepacked_type_key(prepack_type: PrepackTypeConfig):
# For now we we can just use the first accumulator type seen since
# the tensor core shapes/layouts don't vary based on accumulator
# type so we can generate less code this way
return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)
unique_prepack_types = []
prepack_types_seen = set()
for prepack_type in prepack_types:
key = prepacked_type_key(prepack_type)
if key not in prepack_types_seen:
unique_prepack_types.append(prepack_type)
prepack_types_seen.add(key)
sources.append(
(
"machete_prepack",
prepack_dispatch_template.render(
types=unique_prepack_types,
),
)
)
# Split up impls across files
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
num_impls_per_file = math.ceil(num_impls / num_impl_files)
files_impls: list[list[ImplConfig]] = [[]]
curr_num_impls_assigned = 0
curr_impl_in_file = 0
curr_impl_configs = deepcopy(list(reversed(impl_configs)))
while curr_num_impls_assigned < num_impls:
room_left_in_file = num_impls_per_file - curr_impl_in_file
if room_left_in_file == 0:
files_impls.append([])
room_left_in_file = num_impls_per_file
curr_impl_in_file = 0
curr_ic = curr_impl_configs[-1]
if len(curr_ic.schedules) >= room_left_in_file:
# Break apart the current impl config
tmp_ic = deepcopy(curr_ic)
tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
files_impls[-1].append(tmp_ic)
else:
files_impls[-1].append(curr_ic)
curr_impl_configs.pop()
curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
curr_impl_in_file += len(files_impls[-1][-1].schedules)
for part, file_impls in enumerate(files_impls):
sources.append(
(
f"machete_mm_impl_part{part+1}",
mm_impl_template.render(impl_configs=file_impls),
)
)
return sources
def generate():
# See csrc/quantization/machete/Readme.md, the Codegeneration for more info
# about how this works
SCRIPT_DIR = os.path.dirname(__file__)
sch_common_params = dict(
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)
# Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
default_tile_heuristic_config = {
# M = 257+
"M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
"M > 256": ((128, 256), (2, 1, 1)),
# M = 129-256
"M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
"M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
"M > 128": ((128, 256), (2, 1, 1)),
# M = 65-128
"M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
"M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
"M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
"M > 64": ((128, 128), (2, 1, 1)),
# M = 33-64
"M > 40 && K <= 6144 && N <= 6144": ((128, 32), (2, 1, 1)),
"M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
"M > 32": ((128, 64), (2, 1, 1)),
# M = 17-32
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
"M > 16": ((256, 32), (2, 1, 1)),
# M = 1-16
"N >= 26624": ((256, 16), (1, 1, 1)),
None: ((128, 16), (1, 1, 1)),
}
# For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic = [
(cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
for cond, tile_config in default_tile_heuristic_config.items()
]
def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):
# Do not use schedules = list(set(...)) because we need to make sure
# the output list is deterministic; otherwise the generated kernel file
# will be non-deterministic and causes ccache miss.
schedules = []
for _, schedule_config in heuristic:
if schedule_config not in schedules:
schedules.append(schedule_config)
return schedules
impl_configs = []
GPTQ_kernel_type_configs = list(
TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.void,
a_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
)
for b in (MACHETEDataType.u4b8, MACHETEDataType.u8b128)
for a in (DataType.f16, DataType.bf16)
)
impl_configs += [
ImplConfig(x[0], x[1], x[2])
for x in zip(
GPTQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic),
)
]
output_dir = os.path.join(SCRIPT_DIR, "generated")
# Delete the "generated" directory if it exists
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
# Create the "generated" directory
os.makedirs(output_dir)
# Render each group of configurations into separate files
for filename, code in create_sources(impl_configs):
filepath = os.path.join(output_dir, f"{filename}.cu")
with open(filepath, "w") as output_file:
output_file.write(code)
print(f"Rendered template to {filepath}")
if __name__ == "__main__":
generate()

View File

@@ -1,31 +0,0 @@
#pragma once
#include "utils/machete_collective_builder.cuh"
#include "machete_mainloop.cuh"
namespace cutlass::gemm::collective {
using namespace cute;
struct MacheteKernelTag {};
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
class ElementPairB_, class GmemLayoutB_, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct MacheteCollectiveBuilder<
MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType,
cute::enable_if_t<(
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative>)>> {
using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>;
};
}; // namespace cutlass::gemm::collective

View File

@@ -1,85 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from typing import Union
from cutlass_library import (
DataType,
DataTypeNames,
DataTypeSize,
DataTypeTag,
KernelScheduleTag,
KernelScheduleType,
enum_auto,
)
#
# Extend cutlass library with custom types, and missing values
#
class MACHETEDataType(enum.Enum):
u4b8 = enum_auto()
u8b128 = enum_auto()
class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedPingpong = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
MACHETEDataTypeNames: dict[Union[MACHETEDataType, DataType], str] = {
**DataTypeNames, # type: ignore
**{
MACHETEDataType.u4b8: "u4b8",
MACHETEDataType.u8b128: "u8b128",
},
}
MACHETEDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
**DataTypeTag, # type: ignore
**{
MACHETEDataType.u4b8: "cutlass::machete_uint4b8_t",
MACHETEDataType.u8b128: "cutlass::machete_uint8b128_t",
},
}
MACHETEDataTypeSize: dict[Union[MACHETEDataType, DataType], int] = {
**DataTypeSize, # type: ignore
**{
MACHETEDataType.u4b8: 4,
MACHETEDataType.u8b128: 8,
},
}
MACHETEDataTypeMACHETEScalarTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
MACHETEDataType.u4b8: "machete::kU4B8",
MACHETEDataType.u8b128: "machete::kU8B128",
DataType.u4: "machete::kU4",
DataType.u8: "machete::kU8",
DataType.s4: "machete::kS4",
DataType.s8: "machete::kS8",
DataType.f16: "machete::kFloat16",
DataType.bf16: "machete::kBfloat16",
}
MACHETEDataTypePaddleDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
DataType.u8: "paddle::DataType::UINT8",
DataType.s8: "paddle::DataType::INT8",
DataType.e4m3: "paddle::DataType::FLOAT8_E4M3FN",
DataType.s32: "paddle::DataType::INT32",
DataType.f16: "paddle::DataType::FLOAT16",
DataType.bf16: "paddle::DataType::BFLOAT16",
DataType.f32: "paddle::DataType::FLOAT32",
}
MACHETEKernelScheduleTag: dict[Union[MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
},
}

View File

@@ -1,35 +0,0 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace machete {
using namespace cute;
// get an interleaved block layout where each element consecutive element has a
// stride of bit_stride and the block width is blk_bit_width,
// examples:
// size_bits<T> = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1
// size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
static_assert(blk_bit_width % bit_stride == 0);
static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);
constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;
if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
// identity layout
return Layout<Shape<Int<elems_per_blk>>>{};
} else {
constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
constexpr auto num_strides = elems_per_blk / elems_per_stride;
return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
Stride<Int<elems_per_stride>, Int<1>>>{};
}
}
}; // namespace machete

File diff suppressed because it is too large Load Diff

View File

@@ -1,88 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
template <typename T>
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
}
paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
int64_t b_type_id,
std::optional<paddle::DataType> const& maybe_out_type,
std::optional<paddle::Tensor> const& maybe_group_scales,
std::optional<paddle::Tensor> const& maybe_group_zeros,
int64_t maybe_group_size,
std::optional<paddle::Tensor> const& maybe_channel_scales,
std::optional<paddle::Tensor> const& maybe_token_scales,
std::string maybe_schedule) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
std::optional<std::string> maybe_schedule_opt;
if (maybe_schedule == "") {
maybe_schedule_opt = std::nullopt;
} else {
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
}
return machete::mm_dispatch({.A = A,
.B = B,
.b_type = b_type,
.maybe_out_type = maybe_out_type,
.maybe_group_scales = maybe_group_scales,
.maybe_group_zeros = maybe_group_zeros,
.maybe_group_size = maybe_group_size_opt,
.maybe_channel_scales = maybe_channel_scales,
.maybe_token_scales = maybe_token_scales,
.maybe_schedule = maybe_schedule_opt});
}
std::vector<paddle::Tensor> MacheteMMKernel(
paddle::Tensor const& A, paddle::Tensor const& B,
paddle::optional<paddle::Tensor> const& maybe_group_scales,
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
paddle::optional<paddle::Tensor> const& maybe_token_scales,
std::string const& b_type_str,
std::string const& maybe_out_type_str,
int64_t const& maybe_group_size,
std::string const& maybe_schedule
) {
machete::ScalarTypeId b_type_id;
paddle::DataType maybe_out_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
if (maybe_out_type_str == "float16") {
maybe_out_type = paddle::DataType::FLOAT16;
} else if (maybe_out_type_str == "bfloat16") {
maybe_out_type = paddle::DataType::BFLOAT16;
} else {
maybe_out_type = A.dtype();
}
auto out = mm(A, B, b_type_id, maybe_out_type,
ConvertToStdOptional<paddle::Tensor>(maybe_group_scales),
ConvertToStdOptional<paddle::Tensor>(maybe_group_zeros),
maybe_group_size,
ConvertToStdOptional<paddle::Tensor>(maybe_channel_scales),
ConvertToStdOptional<paddle::Tensor>(maybe_token_scales),
maybe_schedule);
return {out};
}

View File

@@ -1,305 +0,0 @@
#pragma once
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "utils/cute_utils.cuh"
#include "utils/machete_numeric_conversion.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "utils/paddle_utils.hpp"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
// instructions only support sourcing from registers for the left-hand
// operand, we want to upconvert/decompress the quantized operand in
// register. Since the primary use case we want to support is Y = XW^t where
// W is quantized, in this situation or right-hand operand is quantized so
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
typename ScheduleConfig>
struct MacheteKernelTemplate {
static constexpr bool with_C = false; // not ever used
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
static constexpr bool with_group_zeropoints =
!std::is_same_v<GroupZeroT, void>;
static constexpr bool with_channel_scales =
!std::is_same_v<ChannelScaleT, void>;
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementAccumulator = AccumulatorT;
using ElementCompute = AccumulatorT; // For Epilogue
// Use dummy values when we don't have scales or zeropoints
using ElementZGroup =
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
using ElementSGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementConvertGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementSChannel =
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
using ElementSToken =
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
using BTypeTuple = cute::conditional_t<
with_group_scales,
cute::conditional_t<with_group_zeropoints,
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
cute::tuple<ElementB, ElementSGroup>>,
ElementB>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using LayoutScale = cutlass::layout::RowMajor;
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;
// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZGroup = StrideSGroup;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
static int constexpr AlignmentC =
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
cute::Int<TileShapeK>{}));
using ClusterShape = typename ScheduleConfig::ClusterShape;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
static_assert(
(!with_channel_scales && !with_token_scales) ||
((with_channel_scales && with_token_scales) &&
std::is_same_v<ElementSChannel, ElementSToken>),
"Currently token and channel scales (if present) must be the same type");
// Currently only supports float scales
using ChTokScalesEpilogue =
typename fastdeploy::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
TileShape>;
static_assert((with_channel_scales || with_token_scales) ||
(std::is_same_v<ElementSChannel, float> &&
std::is_same_v<ElementSToken, float>),
"Currently token and channel scales (if present) must be float "
"(and if one is present the other must be too)");
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90AccFetch>;
using EVTCompute =
std::conditional_t<with_channel_scales || with_token_scales,
typename ChTokScalesEpilogue::EVTCompute,
StoreEpilogueCompute>;
// EVTCompute
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
EVTCompute>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::MacheteCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
static Arguments create_arguments(
cudaStream_t stream,
paddle::Tensor const& A, // MxK matrix
paddle::Tensor const& B, // KxN prepacked matrix
paddle::Tensor& D, // MxN matrix
std::optional<paddle::Tensor> const& maybe_g_scales, // scale_KxN matrix
std::optional<paddle::Tensor> const& maybe_g_zeros, // scale_KxN matrix
std::optional<int64_t> maybe_group_size,
std::optional<paddle::Tensor> const& maybe_ch_scales, // len N vector
std::optional<paddle::Tensor> const& maybe_tok_scales) // len M vector
{
static_assert(!with_group_zeropoints || with_group_scales);
int M = A.shape()[0], N = B.shape()[1], K = A.shape()[1];
PD_CHECK(D.shape()[0] == M && D.shape()[1] == N);
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_S_group =
maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
auto layout_Z_group =
maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
auto unwrap = [](auto const& t) {
return t ? t->data() : nullptr;
};
auto A_ptr = static_cast<ElementA const*>(A.data());
auto B_ptr = static_cast<ElementB const*>(B.data());
auto D_ptr = static_cast<ElementD*>(D.data());
auto S_group_ptr =
static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
auto S_channel_ptr =
static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
auto S_token_ptr =
static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;
PD_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
PD_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_group_scales) {
PD_CHECK(S_group_ptr && layout_S_group);
PD_CHECK((size<0>(*layout_S_group) == scale_k &&
size<1>(*layout_S_group) == N));
} else {
PD_CHECK(!S_group_ptr, "Scales not supported");
}
if constexpr (with_group_zeropoints) {
PD_CHECK(Z_group_ptr && layout_Z_group);
PD_CHECK((size<0>(*layout_Z_group) == scale_k &&
size<1>(*layout_Z_group) == N));
PD_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
"Scales and zeros must have the same layout");
} else {
PD_CHECK(!Z_group_ptr, "Zeropoints not supported");
}
if constexpr (with_channel_scales || with_token_scales) {
PD_CHECK(
(maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
(maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
}
// Transpose A and D
// A doesn't need to be transposed since cutlass expects a NxK matrix
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
MainloopArguments mainloop_arguments{};
// {Accum, C, C_layout, D, D}
EpilogueArguments epilogue_arguments{};
if constexpr (with_channel_scales || with_token_scales) {
epilogue_arguments =
EpilogueArguments{ChTokScalesEpilogue::prepare_args(
*maybe_ch_scales, *maybe_tok_scales),
nullptr,
{},
D_ptr,
stride_Dt};
} else {
epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
}
if constexpr (with_group_scales && with_group_zeropoints) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
} else if constexpr (with_group_scales) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
}
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
mainloop_arguments,
epilogue_arguments};
};
static size_t get_workspace_size(Arguments const& args) {
return Gemm::get_workspace_size(args);
}
static bool can_implement(Arguments const& args) {
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
}
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
Gemm gemm_op;
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
PD_CHECK(status == cutlass::Status::kSuccess,
"Machete kernel failed to initialize workspace");
status = gemm_op.run(stream);
PD_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
}
};
}; // namespace machete

View File

@@ -1,78 +0,0 @@
#pragma once
#include <Python.h>
#include "machete_mm_kernel.cuh"
#include "utils/paddle_utils.hpp"
#include "utils/scalar_type.h"
namespace machete {
struct MMArgs {
paddle::Tensor const& A;
paddle::Tensor const& B;
machete::ScalarType const& b_type;
std::optional<paddle::DataType> const& maybe_out_type;
std::optional<paddle::Tensor> const& maybe_group_scales;
std::optional<paddle::Tensor> const& maybe_group_zeros;
std::optional<int64_t> maybe_group_size;
std::optional<paddle::Tensor> const& maybe_channel_scales;
std::optional<paddle::Tensor> const& maybe_token_scales;
std::optional<std::string> maybe_schedule;
};
struct SupportedSchedulesArgs {
paddle::DataType a_type;
machete::ScalarType b_type;
std::optional<paddle::DataType> maybe_group_scales_type;
std::optional<paddle::DataType> maybe_group_zeros_type;
std::optional<paddle::DataType> maybe_channel_scales_type;
std::optional<paddle::DataType> maybe_token_scales_type;
std::optional<paddle::DataType> maybe_out_type;
};
paddle::Tensor mm_dispatch(MMArgs args);
std::vector<std::string> supported_schedules_dispatch(
SupportedSchedulesArgs args);
template <typename MacheteKernel>
paddle::Tensor run_impl(MMArgs args) {
// const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
// auto device = args.A.device();
// auto stream = at::cuda::getCurrentCUDAStream(device.index());
auto place = args.A.place();
cudaStream_t stream = args.A.stream();
int M = args.A.shape()[0];
int N = args.B.shape()[1];
int K = args.A.shape()[1];
// Allocate output
paddle::Tensor D = paddle::empty(
{M, N},
equivalent_scalar_type_v<typename MacheteKernel::ElementD>,
place);
auto arguments = MacheteKernel::create_arguments(
stream, //
args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
args.maybe_group_size, args.maybe_channel_scales,
args.maybe_token_scales);
PD_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");
size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
int S = static_cast<int>(workspace_size);
// phi::Allocator* allocator = paddle::GetAllocator(place);
// auto workspace = allocator->Allocate(workspace_size);
// MacheteKernel::run(arguments, workspace->ptr(), stream);
// paddle::Tensor workspace = paddle::empty({S}, paddle::DataType::UINT8, place);
paddle::Tensor workspace = GetEmptyTensor({S}, paddle::DataType::UINT8, place);
MacheteKernel::run(arguments, workspace.data(), stream);
return D;
};
}; // namespace machete

View File

@@ -1,73 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
paddle::Tensor prepack_B(
paddle::Tensor const& B, paddle::DataType const& a_type, int64_t b_type_id,
std::string const& maybe_group_scales_type_str) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<paddle::DataType> maybe_group_scales_type;
if (maybe_group_scales_type_str == "float16") {
maybe_group_scales_type = paddle::DataType::FLOAT16;
}
else if (maybe_group_scales_type_str == "bfloat16") {
maybe_group_scales_type = paddle::DataType::BFLOAT16;
}
else if (maybe_group_scales_type_str == "float32") {
maybe_group_scales_type = paddle::DataType::FLOAT32;
}
else if (maybe_group_scales_type_str == "") {
maybe_group_scales_type = std::nullopt;
}
else {
PADDLE_ENFORCE(false, "maybe_group_scales_type_str not supported!");
}
return machete::prepack_B_dispatch(
{.B = B,
.a_type = a_type,
.b_type = b_type,
.maybe_group_scales_type = maybe_group_scales_type});
}
std::vector<paddle::Tensor> MachetePrepackBKernel(
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
std::string const& maybe_group_scales_type_str) {
machete::ScalarTypeId b_type_id;
paddle::DataType a_type, maybe_group_scales_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
if (a_type_str == "float16") {
a_type = paddle::DataType::FLOAT16;
}
else if (a_type_str == "bfloat16") {
a_type = paddle::DataType::BFLOAT16;
}
else {
PADDLE_ENFORCE(false, "a_type_str not supported!");
}
auto Bt = paddle::experimental::transpose(B, {1, 0});
paddle::Tensor B_prepacked = prepack_B(Bt, a_type, b_type_id, maybe_group_scales_type_str);
return {B_prepacked};
}

View File

@@ -1,76 +0,0 @@
#pragma once
#include "machete_mm_kernel.cuh"
#include "utils/cute_utils.cuh"
#include "utils/paddle_utils.hpp"
namespace machete {
template <int threads, typename PrepackedLayoutB, typename BInTensor,
typename ElementB>
static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
auto constexpr block_size =
Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
auto constexpr eles_per_thread = Int<block_size / threads>{};
static_assert(block_size % threads == 0,
"block_size must be divisible by the number of threads");
// Which pre-packed are we responsible for
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
auto tB_in = local_tile(
B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
blk_coord);
// Find the start offset in the output for this pre-packed block
auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
// Tensor representing a 1:1 mapping to the output space in 1D
auto tB_out_linear =
make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
make_layout(make_shape(block_size)));
// Mapping from output space (1D) to input space
auto tB_in_linear = make_tensor(
tB_in.data(),
tB_in.layout()
.compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
.with_shape(make_shape(block_size)));
// Tile for this specific thread (could have used a TiledCopy but these work
// best with 2d layouts, this is a simple 1d layout so local_tile is enough,
// we are also not that concerned with performance for this kernel)
auto thr_tB_in_linear =
local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
auto thr_tB_out_linear =
local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
// Construct a register-backed Tensor with the same shape as each thread's
// partition
auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
copy(thr_tB_in_linear, fragment);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
}
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B_template(
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
using TileShapeNKL =
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
auto ilvd_NKbNbKL_to_offset =
PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
PD_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
PD_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
auto L_tiles = size<2>(B_layout);
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
prepack_B_kernel<128, PrepackedLayoutB>
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
}
}; // namespace machete

View File

@@ -1,77 +0,0 @@
#pragma once
#include "machete_prepack_kernel.cuh"
#include "utils/paddle_utils.hpp"
#include "utils/scalar_type.h"
namespace machete {
struct PrepackBArgs {
paddle::Tensor const& B;
paddle::DataType a_type;
machete::ScalarType b_type;
std::optional<paddle::DataType> maybe_group_scales_type;
};
template <typename PrepackedLayoutB>
paddle::Tensor prepack_impl(paddle::Tensor const B) {
// const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
using ElementB = typename PrepackedLayoutB::ElementB;
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
// auto device = B.device();
// auto stream = at::cuda::getCurrentCUDAStream(device.index());
cudaStream_t stream = B.stream();
auto B_ptr = static_cast<ElementB const*>(B.data());
// elements per storage item for B
auto eles_per_storage =
(SizeOf(B.dtype()) * 8) / cute::sizeof_bits_v<ElementB>;
// paddle B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
// auto Bt_packed = B.transpose();
auto Bt_packed = paddle::experimental::transpose(B, {1, 0});
PD_CHECK(
(B.shape()[0] * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
size<1>(PPBlockShape_NK{}));
PD_CHECK(B.shape()[1] % size<0>(PPBlockShape_NK{}) == 0,
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
// auto const l_Bt_packed = make_cute_layout<StrideB>(B, "B");
// convert (N,packed_K,L) layout to (N,K,L) layout
// in effect we want to do: blocked_product(layout_Bt_packed,
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
// Step<_1, _0, _2>{}));
// but blocked_product does not support dynamic strides so we implement the
// equivalent manually,
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
// when s1 == 1
PD_CHECK(stride<1>(l_Bt_packed) == 1, "stride<1>(l_Bt_packed) must be 1");
// clang-format off
auto const layout_Bt = make_layout(
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
return idx == 1 ? ele * eles_per_storage : ele;
}),
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
return idx != 1 ? ele * eles_per_storage : ele;
}));
// clang-format on
// Allocate output
paddle::Tensor D = paddle::empty_like(B);
prepack_B_template<PrepackedLayoutB>(
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.data()));
return D;
};
paddle::Tensor prepack_B_dispatch(PrepackBArgs args);
}; // namespace machete

View File

@@ -1,249 +0,0 @@
#pragma once
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "utils/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
struct IlvBlkLayoutAuto {};
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the same
// shape as prepacked block, all the data for a given thread is contiguous in
// memory. This allows us to use wider shared memory loads when loading B from
// shared memory. The values within a thread are also potentially interlaeved
// inorder to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
typename AccumulatorT, class LayoutB, class KernelSchedule,
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementAccumulator = AccumulatorT;
using ElementMma = MmaType;
// Interleave for 4bit bit types when we are not upconverting to fp8 or int8,
// in those cases case we use a LUT using prmt instructions to upconvert and
// is more efficient if the data is not interleaved For 8bit+ prmt
// instructions makes non-interleaved layouts efficient enough we don't need
// iterleaved layouts (and can reuse more of the existing cutlass converts)
static constexpr bool should_interleave =
sizeof_bits_v<ElementB> <= 4 &&
!std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
!std::is_same_v<ElementConvert_, int8_t>;
// Only use interleaved layouts for subbyte weights,
using IlvdBlkLayout = std::conditional_t<
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
std::conditional_t<
should_interleave,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
void>,
IlvBlkLayout_>;
// TODO (LucasWilkinson): compare the performance for other sizes
// Prepacked block shape, smallest layout atom for loading into registers
// (can contain multiple wgmma instructions worth of data in one block)
// We ideally want this to be configured such that a thread can perform 128bit
// loads, i.e. we amount of data associated with each thread within a
// prepacked block is a multiple of 128bits, when using a cooperative sechdule
// we have 256 threads working a single block at a time, this means each
// thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data,
// for a 4bit type this would be 128bits
using PPBlockShape_NK = Shape<_128, _64>;
// Create the shape of the tile anticipated to be used by the GEMM kernel,
// when the kernel executes we will compute `Ct = Bt * At` since the
// quantized weights (B), must be the lhs operand so the flow through
// registers.
// The _128 here doesn't actually impact the shape of the stored tile directly
// but may impact the op selected by rs_op_selector
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
size<1>(PPBlockShape_NK{})));
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<LayoutB>();
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
AtomLayoutMNK{}));
// Prepacked block, (athrid, val) -> (N,K)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
}
// Prepacked block, (N,K) -> (athrid, val)
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
// Return iterleaved layout
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
auto layout_no_interleave =
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
return layout_no_interleave;
} else {
// interleave by transforming FrgV into interleaved blocks where each
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
// if FrgV is {A, B, C, D, E, F, G, H}
// then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
auto frgV = get<1, 0>(layout_no_interleave);
auto ilvdBlk = IlvdBlkLayout{};
static_assert(size(frgV) % size(ilvdBlk) == 0,
"FrgV must be divisible by size(ilvdBlk)");
auto ilvd_FrgV = make_layout(
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
make_stride(stride(ilvdBlk), size(ilvdBlk)));
// Return iterleaved layout
return make_layout(
get<0>(layout_no_interleave),
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
}
}
// Prepacked block, (M,K) -> (storage_offset)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
// do (M,K) -> (athrid, val) -> (storage_idx)
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
}
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_TV_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L))
// => ((athrid, val), (BlocksN, BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
Shape_NKL shape_mkl) {
auto layout = TVbNbKL_to_offset(shape_mkl);
// for 4-bit elements, having >= 64 values per column
// allows TMA to load full 32-byte sectors
auto inner_layout =
make_layout(make_shape(_256{}, size<0>(layout) / _256{}));
return make_layout(inner_layout, get<1>(layout), get<2>(layout));
}
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN,
// BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// (BlocksN, BlocksK, L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
auto stride = size(PPBlockShape_NK{});
// (BlocksN, BlocksK, L) -> (storage_idx)
return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
}
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
make_layout(size<1>(PPBlockShape_NK{})));
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
return tiled_A.compose(ppblock_TV_to_NK(), _);
}
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
return blocked_product(ppblock_NK_to_TV(),
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
}
};
}; // namespace machete

View File

@@ -1,72 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
template <typename T>
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
}
std::vector<std::string> supported_schedules(
paddle::DataType a_type, int64_t b_type_id,
std::optional<paddle::DataType> maybe_group_scales_type,
std::optional<paddle::DataType> maybe_group_zeros_type,
std::optional<paddle::DataType> maybe_channel_scales_type,
std::optional<paddle::DataType> maybe_token_scales_type,
std::optional<paddle::DataType> maybe_out_type) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
auto schedules = machete::supported_schedules_dispatch({
.a_type = a_type,
.b_type = b_type,
.maybe_group_scales_type = maybe_group_scales_type,
.maybe_group_zeros_type = maybe_group_zeros_type,
.maybe_channel_scales_type = maybe_channel_scales_type,
.maybe_token_scales_type = maybe_token_scales_type,
.maybe_out_type = maybe_out_type
});
return schedules;
}
std::vector<std::string> MacheteSupportedSchedules(
std::string const& a_type_str, std::string const& b_type_str) {
machete::ScalarTypeId b_type_id;
paddle::DataType a_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
if (a_type_str == "bfloat16") {
a_type = paddle::DataType::BFLOAT16;
} else if (a_type_str == "float16") {
a_type = paddle::DataType::FLOAT16;
} else {
PADDLE_ENFORCE(false, "a_type_str not supported!");
}
std::optional<paddle::DataType> maybe_group_scales_type = std::optional<paddle::DataType>(a_type);
std::optional<paddle::DataType> maybe_out_type = std::optional<paddle::DataType>(a_type);
std::optional<paddle::DataType> maybe_group_zeros_type = std::nullopt;
std::optional<paddle::DataType> maybe_channel_scales_type = std::nullopt;
std::optional<paddle::DataType> maybe_token_scales_type = std::nullopt;
auto schedules = supported_schedules(a_type, b_type_id,
maybe_group_scales_type,
maybe_group_zeros_type,
maybe_channel_scales_type,
maybe_token_scales_type,
maybe_out_type);
return schedules;
}

View File

@@ -1,69 +0,0 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/cute_utils.cuh
#pragma once
#include <cute/tensor.hpp>
namespace cute {
////////////////////////////////////////////////////////////////////
// layout utils
////////////////////////////////////////////////////////////////////
// Permute layout based on indices, example:
// permute_layout<1, 0>(layout) will swap the two dimensions
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
template <size_t... I, typename Layout>
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
return cute::make_layout(cute::get<I>(l)...);
}
// is the layout f(x) = x
template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>) {
return true;
} else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
return true;
}
return false;
}
}
////////////////////////////////////////////////////////////////////
// Pointer utils
////////////////////////////////////////////////////////////////////
template <class PointerType>
static constexpr auto get_logical_ptr(PointerType* ptr) {
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
return cute::subbyte_iterator<PointerType>(ptr);
} else {
return ptr;
}
}
////////////////////////////////////////////////////////////////////
// Misc utils
////////////////////////////////////////////////////////////////////
template <typename T, typename Elements>
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
constexpr auto bits = sizeof_bits_v<T> * Elements{};
if constexpr (bits % 128 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<128>{};
} else if constexpr (bits % 64 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<64>{};
} else if constexpr (bits % 32 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<32>{};
} else if constexpr (bits % 16 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<16>{};
} else {
return AutoVectorizingCopyWithAssumedAlignment<8>{};
}
}
}; // namespace cute

View File

@@ -1,44 +0,0 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_collective_builder.cuh
#pragma once
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
namespace cutlass::gemm::collective {
using namespace cute;
//
// MacheteCollectiveBuilder is a wrapper around CollectiveBuilder that allows for
// for custom kernel tags, allowing you to build custom collectives. Without
// touching the cutlass library headers, using `CutlassKernelTag` will mean it
// will resort to using the standard cutlass collective builder.
//
// Use the default Cutlass collective builder, i.e. use an unmodified cutless
// collective
struct CutlassKernelTag {};
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, class Enable = void>
struct MacheteCollectiveBuilder {
static_assert(sizeof(ElementA) == 0,
"Could not build a collective for given parameters.");
};
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct MacheteCollectiveBuilder<
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType> {
using CollectiveOp = typename CollectiveBuilder<
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
};
}; // namespace cutlass::gemm::collective

View File

@@ -1,51 +0,0 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_custom_types.cuh
#pragma once
#include "cutlass/integer_subbyte.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed = false>
struct machete_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
using Base = integer_subbyte<Bits, Signed>;
using Storage = typename Base::Storage;
using xint_t = typename Base::xint_t;
using Base::bits_mask_;
using Base::sign_mask_;
using Base::storage;
//
// Methods
//
/// No operation
machete_biased_integer_subbyte() = default;
/// Conversion from integer type
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(int value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(unsigned value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(double value)
: Base(value) {}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// "GPTQ" types, i.e. symmetric quantization
using machete_uint4b8_t = machete_biased_integer_subbyte<4, 8>; // u4b8
using machete_uint8b128_t = machete_biased_integer_subbyte<8, 128>; // u8b128
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed>
struct sizeof_bits<machete_biased_integer_subbyte<Bits, Bias, Signed>> {
static constexpr int value = Bits;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@@ -1,993 +0,0 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
#pragma once
#include "cutlass/numeric_conversion.h"
#include "machete_custom_types.cuh"
#include "cute_utils.cuh"
#include "machete_type_utils.cuh"
// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
// with vllm specific type conversions, namely: machete_uint4b8_t, machete_uint8b128_t
// as well as adds interleaved numeric array converters for specific types.
// (interleaved numeric array converters can be more efficient for subbyte
// types)
namespace cutlass {
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout, interleaving can
// make subbyte converts more efficient by allowing for efficient extraction
// of subbyte elements from a 32bit register.
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
class Enable = void>
struct InterleavedNumericArrayConverter {
using Converter = NumericArrayConverter<T, S, N, Round>;
using result_type = typename Converter::result_type;
using source_type = typename Converter::source_type;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
if (cute::elect_one_sync()) {
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
printf(
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
nameof_v<T>, nameof_v<S>, N);
} else {
printf(
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
"implemented\n",
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
}
__brkpt();
}
return {};
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round>
struct InterleavedNumericArrayConverter<
IlvBlkLayout, T, S, N, Round,
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
using Converter = NumericArrayConverter<T, S, N, Round>;
using result_type = typename Converter::result_type;
using source_type = typename Converter::source_type;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return Converter::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
template <typename RegConvert32bit, typename T, typename S, int N>
struct ArrayConverterPacked32Bit {
using result_type = Array<T, N>;
using source_type = Array<S, N>;
using result_packed_8_t = Array<T, 8>;
using result_packed_4_t = Array<T, 4>;
using result_packed_2_t = Array<T, 2>;
using src_packed_8_t = Array<S, 8>;
using src_packed_4_t = Array<S, 4>;
using src_packed_2_t = Array<S, 2>;
static_assert(N % 2 == 0, "N must be a multiple of 2");
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
static constexpr auto src_elems_per_32bit_reg =
32 / cutlass::sizeof_bits_v<S>;
// Maybe not Valid. ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented. However it won't be used
// anyways since we assert N % 2 == 0, just here for compliance with
// VectorizedConverter.
using ScalarConverter = NumericConverter<T, S>;
template <typename PackedSrc>
CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
if constexpr (sizeof(PackedSrc) == 1) {
return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 2) {
return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 4) {
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
} else {
static_assert(sizeof(PackedSrc) == 8);
return reinterpret_cast<Array<uint32_t, 2> const&>(src);
}
}
// The core converter uses bit tricks to construct a known FP16 number, then
// does a subtraction in FP16 for the final result.
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE static PackedResultType packed_convert(
PackedSrcType const& source) {
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
static_assert(PackedResultType::kElements == 2 ||
PackedResultType::kElements == 4 ||
PackedResultType::kElements == 8,
"Invalid PackedResultType must be 2, 4 or 8.");
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
}
friend class detail::VectorizedConverter;
public:
CUTLASS_DEVICE static result_type convert(source_type const& source) {
result_type result;
using ConverterType =
ArrayConverterPacked32Bit<RegConvert32bit,
typename result_type::Element,
typename source_type::Element, N>;
if constexpr (src_elems_per_32bit_reg >= 8) {
detail::VectorizedConverter::convert<
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
} else if constexpr (src_elems_per_32bit_reg >= 4) {
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
src_packed_4_t, result_packed_2_t,
src_packed_2_t>(result, source);
} else {
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
src_packed_2_t>(result, source);
}
return result;
}
};
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
// into 2 32bit register.
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3, //
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7, //
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11, //
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
uint32_t src) {
cutlass::AlignedArray<uint32_t, 2> r;
// Determines if the value is in the top half of the LUT if set or
// (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move
// into bit position 0x4 of each nibble so when or'd with final_prmt_base it
// selects the correct candidate. When elements in final_prmt_base
// are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
// are < 0x4, the low candidate is selected (i.e. LUT[0:7])
uint32_t high_bit = (src & 0x88888888) >> 1;
// `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT
// (selects correct high or low candidate)
const uint32_t final_prmt_base = 0x32103210;
// Ignore the high bit when indexing into LUT, for each 4bit value
// we index into both the high and low candidates then use
// high_bit | final_prmt_base to select the correct candidate
uint32_t lut_idx = (src & 0x77777777);
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
(uint32_t(d) << 24);
};
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
uint32_t final_prmt_idx = final_prmt_base | high_bit;
// This uses a look up table to convert packed int4s to packed int8s,
// using the int4 value as the index to prmt. It first select both the
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
// select the correct candidate.
asm volatile(
"{\n"
" .reg .b32 low, high;\n"
" prmt.b32 low, %1, %2, %5;\n"
" prmt.b32 high, %3, %4, %5;\n"
" prmt.b32 %0, low, high, %6;\n"
"}\n"
: "=r"(r[ii])
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
"r"(final_prmt_idx));
}
return r;
};
// for Array<int8_t, N> <= Array<machete_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, machete_uint4b8_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
0xFC, 0xFD, 0xFE, 0xFF, //
0x00, 0x01, 0x02, 0x03, //
0x04, 0x05, 0x06, 0x07>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::float_e4m3_t, N> <= Array<machete_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::float_e4m3_t, machete_uint4b8_t, N, Round> {
using result_type = Array<cutlass::float_e4m3_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
0xC8, 0xC4, 0xC0, 0xB8, //
0x00, 0x38, 0x40, 0x44, //
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, machete_uint4b8_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
// Below constructs the following temporary:
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
// We use inline asm instead of __byte_perm intrinsic since we don't want
// the documented (& 0x7) on the index. NVCC might be able to optimize it
// out since the index is a constexpr, but we choose to be safe about it
// here.
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
static_assert(RegArray::kElements <= 4,
"Too many inputs for F16 -> I4 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a fp16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the FP16 to the correct value for the
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
// where x1 in the high nibble and x0 is the low nibble then using hfma
// to subtract 1032 from that
// The AND does the following:
// 1) Clear the set bits for the int4 we will ignore.
// We use lop3 so that we can use 1 instruction for AND and XOR.
static constexpr uint32_t xor_mask = 0x64006400;
static constexpr uint32_t and_mask = 0xFFF0FF0F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// We will issue 2 hfmas that do the following:
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
// = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032}
static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1}
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::half_t, machete_uint4b8_t, N,
Round, void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t xor_mask = 0x64006400;
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
auto src_ = src >> (4 * (ii));
r[ii + 0] = src_;
r[ii + 1] = src_;
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
static constexpr uint32_t high_nib_mask = 0x00F000F0;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 1])
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
// For low nibble:
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
// For high nibble:
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
// - {72, 72}
static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032}
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
fp16x2_val =
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
fp16x2_val = __hfma2(fp16x2_val,
reinterpret_cast<const half2&>(high_nib_scale),
reinterpret_cast<const half2&>(high_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::half_t, uint4_t, N, Round,
void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<uint4_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t xor_mask = 0x64006400;
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
auto src_ = src >> (4 * (ii));
r[ii + 0] = src_;
r[ii + 1] = src_;
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
static constexpr uint32_t high_nib_mask = 0x00F000F0;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 1])
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
// For low nibble:
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
// For high nibble:
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024}
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
fp16x2_val =
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
fp16x2_val = __hfma2(fp16x2_val,
reinterpret_cast<const half2&>(high_nib_scale),
reinterpret_cast<const half2&>(high_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<machete_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, machete_uint8b128_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<machete_uint8b128_t, N>;
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
uint32_t const prmt_indices[2] = {0x5150, 0x5352};
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(r[ii])
: "r"(src), "n"(start_byte_for_fp16),
"r"(prmt_indices[ii]));
}
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
static constexpr uint32_t bias_rep = 0x64806480;
const half2& bias = reinterpret_cast<const half2&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hsub2(fp16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::float, N> <= Array<machete_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<float, machete_uint8b128_t, N, Round> {
using result_type = Array<float, N>;
using source_type = Array<machete_uint8b128_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
PackedResultType r;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
// u8x4 source and stores the result in r (without introducing extra
// cvt.u32.u8 instruction)
uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
// Subtract the magic number 0x4B000000 from tmp in floating-point
// arithmetic to obtain final result
r[ii] -= (8388608.f + 128.f); // fold in -128 bias
}
return r;
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint4b8_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src_reg = src_[0];
// Hold output BF16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
uint32_t src_reg_shifted = src_reg >> 4;
// Below constructs the following temporary:
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
static_assert(RegArray::kElements <= 4,
"Too many inputs for uint4b8_t -> BF16 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a BF16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the BF16 to the correct value for the
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
// and subtracting 136 to get {x1, x0}
static constexpr uint32_t xor_mask = 0x43004300;
static constexpr uint32_t and_mask = 0x000F000F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// We will issue 2 bfmas that do the following:
// high BF16:
// hi_bf16 - 136, lo_bf16 - 136
// This is the BF16 {136, 136} represented as an integer.
static constexpr uint32_t bias_rep = 0x43084308;
const __nv_bfloat162& bias =
reinterpret_cast<const __nv_bfloat162&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hsub2(bf16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
}
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::bfloat16_t, machete_uint4b8_t, N,
Round, void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t or_mask = 0x43004300;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for (int ii = 0; ii < RegArray::kElements; ++ii) {
r[ii] = src >> (4 * ii);
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
// For low nibble:
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136}
{
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
fp16x2_val =
__hsub2(fp16x2_val,
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::bfloat16_t, uint4_t, N, Round,
void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<uint4_t, N>;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t or_mask = 0x43004300;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for (int ii = 0; ii < RegArray::kElements; ++ii) {
r[ii] = src >> (4 * ii);
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
// For low nibble:
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128}
{
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
fp16x2_val =
__hsub2(fp16x2_val,
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint8b128_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<machete_uint8b128_t, N>;
static FloatRoundStyle const round_style = Round;
private:
using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
using src_packed_4_t = Array<machete_uint8b128_t, 4>;
using src_packed_2_t = Array<machete_uint8b128_t, 2>;
// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round> is
// implemented
using ScalarConverter =
NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round>;
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE static PackedResultType packed_convert(
PackedSrcType const& source) {
static_assert(
(platform::is_same<PackedSrcType, src_packed_2_t>::value &&
platform::is_same<PackedResultType, result_packed_2_t>::value) ||
(platform::is_same<PackedSrcType, src_packed_4_t>::value &&
platform::is_same<PackedResultType, result_packed_4_t>::value),
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
"convert dispatch.");
NumericArrayConverter<float, machete_uint8b128_t, PackedResultType::kElements,
Round>
convert_uint8_to_f32;
Array<float, PackedResultType::kElements> tmp =
convert_uint8_to_f32(source);
NumericArrayConverter<cutlass::bfloat16_t, float,
PackedResultType::kElements, Round>
convert_f32_to_bf16_;
return convert_f32_to_bf16_(tmp);
}
friend class detail::VectorizedConverter;
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
result_type result;
using ConverterType =
NumericArrayConverter<typename result_type::Element,
typename source_type::Element, N, Round>;
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
src_packed_4_t, result_packed_2_t,
src_packed_2_t>(result, source);
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
#endif
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<cutlass::half_t, N>;
struct RegConvert {
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <typename PackedResultType, int src_regs>
CUTLASS_DEVICE static PackedResultType convert(
Array<uint32_t, src_regs> src) {
// Hold output int8s in reg. We need 1 reg for every 4 elements
using RegArray = cutlass::AlignedArray<
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
RegArray r;
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
*reinterpret_cast<half2*>(&src[0]) =
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
if constexpr (src_regs > 1) {
*reinterpret_cast<half2*>(&src[1]) =
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
}
static_assert(PackedResultType::kElements <= 4);
uint32_t uint8s;
static constexpr uint32_t MASK_0246 = 0x6420;
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(uint8s)
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
"n"(MASK_0246));
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
return reinterpret_cast<PackedResultType&>(int8s);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -1,43 +0,0 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"
#include "machete_custom_types.cuh"
namespace cutlass {
template <typename T>
struct nameof {
static constexpr char const* value = "unknown";
};
template <typename T>
inline constexpr auto nameof_v = nameof<T>::value;
#define NAMEOF_TYPE(T) \
template <> \
struct nameof<T> { \
static constexpr char const* value = #T; \
};
NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)
NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)
NAMEOF_TYPE(machete_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(machete_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)
}; // namespace cutlass

View File

@@ -1,161 +0,0 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
#pragma once
#include "helper.h"
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;
namespace cute {
namespace detail {
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
seq<I...>) {
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}
template <class F, int... I>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
return make_shape(f(I)...);
}
}; // namespace detail
template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
if constexpr (cute::is_tuple<T>::value) {
return detail::tapply_with_idx(
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
tuple_seq<T>{});
} else {
return f(t);
}
CUTE_GCC_UNREACHABLE;
}
// calls: make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
return detail::make_shape_from_idx(f, make_seq<N>{});
}
}; // namespace cute
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
// shape of the passed in tensor and the strides are of type `Stride` and
// contain the strides of the passed in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed in tensor.
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(paddle::Tensor const& tensor,
std::string_view name = "tensor") {
PD_CHECK(tensor.shape().size() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.shape().size()) {
if constexpr (cute::is_static_v<StrideEle>) {
PD_CHECK(StrideEle::value == tensor.strides()[idx], "Expected ",
name, ".strides()[", idx, "] to be ", StrideEle::value, ", but got ", tensor.strides()[idx], ". ");
return StrideEle{};
} else {
if (tensor.shape()[idx] == 1) {
// use 0 stride for dims with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.strides()[idx];
}
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.shape().size())
return tensor.shape()[idx];
else
return int64_t(1);
});
return make_layout(shape, stride);
}
template <typename Stride>
static inline auto maybe_make_cute_layout(
std::optional<paddle::Tensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
if (tensor) {
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
} else {
return std::optional<Layout>{};
}
}
//
// Paddle dtype to Cutlass Type (equivalent_cutlass_type)
//
template <typename T>
struct equivalent_cutlass_type {
using type = T;
};
template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<phi::dtype::float16> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<phi::dtype::bfloat16> {
using type = cutlass::bfloat16_t;
};
//
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
};
template <typename T>
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = phi::dtype::float16;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = phi::dtype::bfloat16;
};
// get equivalent c10::ScalarType tag from compile time type
template <typename T>
static inline constexpr paddle::DataType equivalent_scalar_type_v =
phi::CppTypeToDataType<equivalent_scalar_type_t<T>>::Type();

View File

@@ -1,372 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include <optional>
#include <variant>
namespace machete {
//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The type definitions on the Python side can be found in: vllm/scalar_type.py
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
public:
enum NanRepr : uint8_t {
NAN_NONE = 0, // nans are not supported
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
NAN_REPR_ID_MAX
};
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
int32_t bias, bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
signed_(signed_),
bias(bias),
finite_values_only(finite_values_only),
nan_repr(nan_repr) {};
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits - 1, true, bias);
}
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits, false, bias);
}
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(uint8_t exponent,
uint8_t mantissa) {
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
bool finite_values_only,
NanRepr nan_repr) {
// PADDLE_ENFORCE(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
// PADDLE_ENFORCE(nan_repr != NAN_IEEE_754,
// "use `float_IEEE754` constructor for floating point types that "
// "follow IEEE 754 conventions");
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
nan_repr);
}
uint8_t const exponent; // size of the exponent field (0 for integer types)
uint8_t const mantissa; // size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
bool const signed_; // flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t const bias; // stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool const finite_values_only; // i.e. no +/-inf if true
NanRepr const nan_repr; // how NaNs are represented
// (not applicable for integer types)
using Id = int64_t;
private:
// Field size in id
template <typename T_>
static constexpr size_t member_id_field_width() {
using T = std::decay_t<T_>;
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
}
template <typename Fn, typename Init, typename Member, typename... Rest>
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
Rest... rest) {
auto new_val = f(val, member);
if constexpr (sizeof...(rest) > 0) {
return reduce_members_helper(f, new_val, rest...);
} else {
return new_val;
};
}
template <typename Fn, typename Init>
constexpr auto reduce_members(Fn f, Init init) const {
// Should be in constructor order for `from_id`
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
finite_values_only, nan_repr);
};
template <typename Fn, typename Init>
static constexpr auto reduce_member_types(Fn f, Init init) {
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
return dummy_type.reduce_members(f, init);
};
static constexpr auto id_size_bits() {
return reduce_member_types(
[](int acc, auto member) -> int {
return acc + member_id_field_width<decltype(member)>();
},
0);
}
public:
// unique id for this scalar type that can be computed at compile time for
// c++17 template specialization this is not needed once we migrate to
// c++20 and can pass literal classes as template parameters
constexpr Id id() const {
static_assert(id_size_bits() <= sizeof(Id) * 8,
"ScalarType id is too large to be stored");
auto or_and_advance = [](std::pair<Id, uint32_t> result,
auto member) -> std::pair<Id, uint32_t> {
auto [id, bit_offset] = result;
auto constexpr bits = member_id_field_width<decltype(member)>();
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
<< bit_offset,
bit_offset + bits};
};
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
}
// create a ScalarType from an id, for c++17 template specialization,
// this is not needed once we migrate to c++20 and can pass literal
// classes as template parameters
static constexpr ScalarType from_id(Id id) {
auto extract_and_advance = [id](auto result, auto member) {
using T = decltype(member);
auto [tuple, bit_offset] = result;
auto constexpr bits = member_id_field_width<T>();
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
((uint64_t(1) << bits) - 1));
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
};
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
std::pair<std::tuple<>, int>{});
return std::apply([](auto... args) { return ScalarType(args...); },
tuple_args);
}
constexpr int64_t size_bits() const {
return mantissa + exponent + is_signed();
}
constexpr bool is_signed() const { return signed_; }
constexpr bool is_integer() const { return exponent == 0; }
constexpr bool is_floating_point() const { return exponent > 0; }
constexpr bool is_ieee_754() const {
return is_floating_point() && finite_values_only == false &&
nan_repr == NAN_IEEE_754;
}
constexpr bool has_nans() const {
return is_floating_point() && nan_repr != NAN_NONE;
}
constexpr bool has_infs() const {
return is_floating_point() && finite_values_only == false;
}
constexpr bool has_bias() const { return bias != 0; }
private:
double _floating_point_max() const {
PADDLE_ENFORCE(mantissa <= 52 && exponent <= 11,
"Cannot represent max/min as a double for type ", str());
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
max_mantissa -= 1;
}
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
PADDLE_ENFORCE(exponent < 11,
"Cannot represent max/min as a double for type ", str());
max_exponent += 1;
}
// adjust the exponent to match that of a double
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
// is the exponent bits), there is some precedent for non-standard biases,
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
// but to avoid premature over complication we are just assuming the
// standard exponent bias until there is a need to support non-standard
// biases
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
uint64_t max_exponent_double =
max_exponent - exponent_bias + exponent_bias_double;
// shift the mantissa into the position for a double and
// the exponent
uint64_t double_raw =
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
return *reinterpret_cast<double*>(&double_raw);
}
constexpr std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
// PADDLE_ENFORCE(size_bits() < 64 || size_bits() == 64 && is_signed(),
// "Cannot represent max as a int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}
constexpr std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
// PADDLE_ENFORCE(is_signed(),
// "We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
} else {
// PADDLE_ENFORCE(!is_signed() || size_bits() <= 64,
// "Cannot represent min as a int64_t");
if (is_signed()) {
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
// then perform an arithmetic shift right to set all the bits above
// (size_bits() - 1) to 1
return {INT64_MIN >> (64 - size_bits())};
} else {
return {int64_t(0)};
}
}
}
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> max() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_max());
}
// Min representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> min() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_min());
}
std::string str() const {
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading f) the scheme is:
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
* flags:
* - no-flags: means it follows IEEE 754 conventions
* - f: means finite values only (no infinities)
* - n: means nans are supported (non-standard encoding)
* for integer types the scheme is:
* `[u]int<size_bits>[b<bias>]`
* - if bias is not present it means its zero
*/
if (is_floating_point()) {
auto ret = "float" + std::to_string(size_bits()) + "_e" +
std::to_string(exponent) + "m" + std::to_string(mantissa);
if (!is_ieee_754()) {
if (finite_values_only) {
ret += "f";
}
if (nan_repr != NAN_NONE) {
ret += "n";
}
}
return ret;
} else {
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
if (has_bias()) {
ret += "b" + std::to_string(bias);
}
return ret;
}
}
constexpr bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent &&
bias == other.bias && signed_ == other.signed_ &&
finite_values_only == other.finite_values_only &&
nan_repr == other.nan_repr;
}
};
using ScalarTypeId = machete::ScalarType::Id;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = machete::ScalarType::int_(4);
static inline constexpr auto kU4 = machete::ScalarType::uint(4);
static inline constexpr auto kU4B8 = machete::ScalarType::uint(4, 8);
static inline constexpr auto kS8 = machete::ScalarType::int_(8);
static inline constexpr auto kU8 = machete::ScalarType::uint(8);
static inline constexpr auto kU8B128 = machete::ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f =
machete::ScalarType::float_(2, 1, true, machete::ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f =
machete::ScalarType::float_(3, 2, true, machete::ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
machete::ScalarType::float_(4, 3, true, machete::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = machete::ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = machete::ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = machete::ScalarType::float_IEEE754(5, 10);
// // Fixed width style names, generally following:
// // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
constexpr auto kInt4 = kS4;
constexpr auto kUint4 = kU4;
constexpr auto kUint4b8 = kU4B8;
constexpr auto kInt8 = kS8;
constexpr auto kUint8 = kU8;
constexpr auto kUint8b128 = kU8B128;
constexpr auto kFloat4_e2m1f = kFE2M1f;
constexpr auto kFloat6_e3m2f = kFE3M2f;
constexpr auto kFloat8_e5m2 = kFE5M2;
constexpr auto kFloat16_e8m7 = kFE8M7;
constexpr auto kFloat16_e5m10 = kFE5M10;
// colloquial names
constexpr auto kHalf = kFE5M10;
constexpr auto kFloat16 = kHalf;
constexpr auto kFloat16Id = kFloat16.id();
constexpr auto kInt32 = phi::DataType::INT32;
constexpr auto kInt64 = phi::DataType::INT64;
constexpr auto kBool = phi::DataType::BOOL;
constexpr auto kFloat8_e4m3fn = phi::DataType::FLOAT8_E4M3FN;
constexpr auto kBFloat16 = phi::DataType::BFLOAT16;
constexpr auto kFloat32 = phi::DataType::FLOAT32;
constexpr auto kByte = phi::DataType::INT8;
}; // namespace machete

View File

@@ -1,330 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn.h"
std::vector<paddle::Tensor> MobaAttention(
const paddle::Tensor& qkv,
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& q_pack_tokens,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& block_tables,
const paddle::Tensor& rope_sin_cos,
const paddle::Tensor& k_block_means,
const paddle::optional<paddle::Tensor>& attn_gate_weight,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time,
const int moba_encoder_top_k_left,
const int moba_encoder_top_k_right,
const int moba_use_encoder_seq_limit,
const int moba_decoder_top_k_left,
const int moba_decoder_top_k_right,
const int moba_use_decoder_seq_limit,
const bool moba_use_mlp,
const std::string &cache_quant_type_str) {
paddle::Tensor out = paddle::empty({qkv.dims()[0], head_num * head_dim}, qkv.dtype(), qkv.place());
if (max_dec_len_this_time > 0) {
MobaDecoderAttnWriteCacheKv(
qkv,
q_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
seq_len_decoder,
key_cache,
value_cache,
block_tables,
rope_sin_cos,
k_block_means,
qkv_bias,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_len,
cache_quant_type_str);
auto qk_gate_weight = MobaQKGemm(
q_input,
k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_dec_len_this_time,
max_dec_len_this_time,
head_num,
kv_head_num,
true,
moba_use_decoder_seq_limit
)[0];
auto qk_gate_topk_idx = QkSortDecoder(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
head_num,
kv_head_num,
moba_decoder_top_k_left,
moba_decoder_top_k_right,
moba_use_decoder_seq_limit
)[0];
MobaDecoderAttn(
q_input,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
key_cache,
value_cache,
block_tables,
k_block_means,
out,
qk_gate_topk_idx,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_len,
moba_use_decoder_seq_limit,
max_dec_len_this_time,
max_dec_len_this_time,
cache_quant_type_str
);
}
if (max_enc_len_this_time > 0) {
FusedBlockMeanAndRope(
qkv,
k_block_means,
q_input,
k_input,
v_input,
rope_sin_cos,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
qkv_bias,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_enc_len_this_time,
cache_quant_type_str
);
MobaEncoderAttnWriteCacheKv(
k_input,
v_input,
cu_seq_k,
seq_len_encoder,
seq_len_decoder,
key_cache,
value_cache,
block_tables,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_enc_len_this_time,
cache_quant_type_str
);
GetKVFromCache(
k_input,
v_input,
cu_seq_k,
seq_len_encoder,
seq_len_decoder,
key_cache,
value_cache,
block_tables,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time + max_dec_len_this_time,
cache_quant_type_str
);
paddle::Tensor *k_gate_weight = const_cast<paddle::Tensor*>(&k_block_means);
if (moba_use_mlp && attn_gate_weight) {
paddle::Tensor k_gate_mlp = MobaMlpEinsum(
k_input,
attn_gate_weight.get(),
seq_len_encoder,
seq_len_decoder,
cu_seq_k,
max_seq_len,
kv_head_num
)[0];
k_gate_weight = &k_gate_mlp;
}
auto qk_gate_weight = MobaQKGemm(
q_input,
*k_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_enc_len_this_time,
max_enc_len_this_time + max_dec_len_this_time,
head_num,
kv_head_num,
false,
moba_use_encoder_seq_limit
)[0];
auto qk_gate_topk_idx = QkSortEncoder(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
q_pack_tokens,
max_enc_len_this_time,
max_enc_len_this_time + max_dec_len_this_time,
head_num,
kv_head_num,
moba_encoder_top_k_left,
moba_encoder_top_k_right,
moba_use_mlp && !attn_gate_weight ? max_seq_len : moba_use_encoder_seq_limit)[0];
MobaEncoderAttn(
q_input,
k_input,
v_input,
qk_gate_topk_idx,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
seq_len_encoder,
seq_len_decoder,
out,
max_enc_len_this_time,
max_enc_len_this_time + max_dec_len_this_time,
head_num,
kv_head_num,
head_dim,
max_seq_len
);
}
return {out};
}
PD_BUILD_OP(moba_attention)
.Inputs({
"qkv",
"q_input",
"k_input",
"v_input",
"cu_seq_q",
"cu_seq_k",
"cu_seq_q_pack",
"q_pack_tokens",
"seq_len_encoder",
"seq_len_decoder",
"key_cache",
"value_cache",
"block_tables",
"rope_sin_cos",
"k_block_means",
paddle::Optional("attn_gate_weight"),
paddle::Optional("qkv_bias"),
paddle::Optional("cache_k_quant_scale"),
paddle::Optional("cache_v_quant_scale"),
paddle::Optional("cache_k_dequant_scale"),
paddle::Optional("cache_v_dequant_scale"),
paddle::Optional("cache_k_zero_points"),
paddle::Optional("cache_v_zero_points")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_seq_len: int",
"max_enc_len_this_time: int",
"max_dec_len_this_time: int",
"moba_encoder_top_k_left: int",
"moba_encoder_top_k_right: int",
"moba_use_encoder_seq_limit: int",
"moba_decoder_top_k_left: int",
"moba_decoder_top_k_right: int",
"moba_use_decoder_seq_limit: int",
"moba_use_mlp: bool",
"cache_quant_type_str: std::string"})
.Outputs({
"out",
"q_input_out",
"k_input_out",
"v_input_out",
"key_cache_out",
"value_cache_out",
"k_block_means_out"})
.SetInplaceMap({{
"q_input", "q_input_out"},
{"k_input", "k_input_out"},
{"v_input", "v_input_out"},
{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"},
{"k_block_means", "k_block_means_out"}})
.SetKernelFn(PD_KERNEL(MobaAttention));

View File

@@ -1,204 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/extension.h"
void MobaDecoderAttnWriteCacheKv(
const paddle::Tensor& qkv_out,
const paddle::Tensor& q_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& rope_sin_cos,
const paddle::Tensor& k_block_means,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const std::string &cache_quant_type_str);
void MobaEncoderAttnWriteCacheKv(
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_q,
const std::string &cache_quant_type_str);
void MobaDecoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& k_block_means,
const paddle::Tensor& out,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const std::string &cache_quant_type_str);
void FusedBlockMeanAndRope(
const paddle::Tensor& qkv_out,
const paddle::Tensor& k_block_means,
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::optional<paddle::Tensor>& qkv_bias,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int max_seq_q,
const int max_seq_k,
const std::string &cache_quant_type_str);
std::vector<paddle::Tensor> GetCurCuSeqLenk(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int pack_size);
std::vector<paddle::Tensor> MobaQKGemm(
const paddle::Tensor& q_input,
const paddle::Tensor& k_block_means,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const bool is_split_kv,
const int use_moba_seq_limit);
std::vector<paddle::Tensor> QkSortDecoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit);
void GetKVFromCache(
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int max_seq_k,
const std::string &cache_quant_type_str);
void MobaEncoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& out,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length);
std::vector<paddle::Tensor> QkSortEncoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& q_pack_tokens,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit);
std::vector<paddle::Tensor> MobaMlpEinsum(
const paddle::Tensor& k_input,
const paddle::Tensor& attn_gate_weight,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seq_k,
const int max_seq_len,
const int kv_head_num);

View File

@@ -1,748 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cub/cub.cuh>
#include "cute/tensor.hpp"
#include "cute/algorithm/copy.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/int_tuple.hpp"
#include <cute/arch/cluster_sm90.hpp>
#include <cub/cub.cuh>
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
using namespace cute;
template<typename T>
struct PackedHalf;
template<>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
};
template<>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
};
template<>
struct PackedHalf<phi::dtype::float16> {
using Type = __half2;
};
template<>
struct PackedHalf<phi::dtype::bfloat16> {
using Type = nv_bfloat162;
};
template<typename T>
struct HalfSub;
template<>
struct HalfSub<cutlass::half_t> {
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num));
}
};
template<>
struct HalfSub<cutlass::bfloat16_t> {
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
*reinterpret_cast<nv_bfloat162*>(result_ptr) -= *reinterpret_cast<const nv_bfloat162*>(&magic_num);
}
};
template<typename T>
struct HalfMul;
template<>
struct HalfMul<cutlass::half_t> {
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num));
}
};
template<>
struct HalfMul<cutlass::bfloat16_t> {
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
*reinterpret_cast<nv_bfloat162*>(result_ptr) *= *reinterpret_cast<const nv_bfloat162*>(&magic_num);
}
};
template<typename T>
struct HalfMax;
template<>
struct HalfMax<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("max.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<>
struct HalfMax<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("max.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<typename T>
struct HalfMin;
template<>
struct HalfMin<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("min.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<>
struct HalfMin<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("min.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
};
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
template<typename T>
struct MinOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
};
template <>
struct MinOp<float> {
__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
};
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
template<typename T, bool Is_K>
inline __device__ static void convert_c8_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) {
uint32_t* half_result_ptr = reinterpret_cast<uint32_t*>(dst);
if constexpr (std::is_same_v<T, cutlass::bfloat16_t>) {
static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(*src, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(*src, fp32_base, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(*src, fp32_base, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(*src, fp32_base, 0x7653);
#pragma unroll
for (int ii = 0; ii < 4; ++ii) {
fp32_intermediates[ii] -= 8388608.f;
}
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
half_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
}
} else {
static constexpr uint32_t head_for_fp16 = 0x64006400;
half_result_ptr[0] = __byte_perm(*src, head_for_fp16, 0x7150);
half_result_ptr[1] = __byte_perm(*src, head_for_fp16, 0x7352);
}
using pack_half = typename PackedHalf<T>::Type;
#pragma unroll
for (int i = 0; i < 2; i++){
if constexpr (Is_K) {
HalfSub<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_zp + i * 2));
HalfMul<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_scale + i * 2));
} else {
pack_half bias;
pack_half scale;
bias.x = cache_zp[0];
bias.y = cache_zp[0];
scale.x = cache_scale[0];
scale.y = cache_scale[0];
HalfSub<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(&bias));
HalfMul<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(&scale));
}
}
}
template<typename T, bool Is_K>
inline __device__ static void convert_c4_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) {
using pack_half = typename PackedHalf<T>::Type;
static constexpr uint32_t MASK = 0x0f0f0f0f;
static constexpr uint32_t head_for_fp16 = std::is_same_v<T, cutlass::bfloat16_t> ? 0x43004300 : 0x64006400;
static constexpr uint32_t mask_for_c42fp16_one = 0x7253;
static constexpr uint32_t mask_for_c42fp16_two = 0x7051;
uint32_t* result_ptr = reinterpret_cast<uint32_t*>(dst);
uint32_t source = *reinterpret_cast<uint32_t const*>(src);
// source = {e0 e4 e1 e5 e2 e6 e3 e7}
uint32_t bottom_i4s = source & MASK;
// bottom_i4s = {0 e4 0 e5 0 e6 0 e7}
uint32_t top_i4s = (source >> 4) & MASK;
// top_i4s = {0 e0 0 e1 0 e2 0 e3}
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[0]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one));
// result_ptr[0] = {e0 e1}
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[1]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two));
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[2]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one));
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[3]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two));
#pragma unroll
for (int i = 0; i < 4; ++i) {
if constexpr (Is_K) {
const int ith_col = i % 2 * 2;
HalfSub<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_zp + ith_col));
HalfMul<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_scale + ith_col));
} else {
const int ith_col = i / 2;
pack_half bias;
pack_half scale;
bias.x = cache_zp[ith_col];
bias.y = cache_zp[ith_col];
scale.x = cache_scale[ith_col];
scale.y = cache_scale[ith_col];
HalfSub<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(&bias));
HalfMul<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(&scale));
}
}
}
template<typename CacheKV_traits, typename T, int kHeadDim, int kDataNumPer2Byte, bool A_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename ThrCopy0, typename TiledCopy0>
inline __device__ void gemm_qk_quant(
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB,
Tensor4 const& sB, TiledMma tiled_mma,
ThrCopy0 smem_thr_copy_A,
TiledCopy0 smem_tiled_copy_A,
const int32_t tidx,
const T * cache_scale, const T * cache_zp) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
if (!A_in_regs) {
copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
}
uint32_t *sBdata = reinterpret_cast<uint32_t *>(sB.data().get()) + tidx * (kDataNumPer2Byte / 4);
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) {
copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
}
}
if constexpr (kDataNumPer2Byte == 4) {
convert_c4_2_half<T, true>(sBdata + i * kHeadDim, tCrB.data(), cache_scale + i * 4, cache_zp + i * 4);
} else {
convert_c8_2_half<T, true>(sBdata + i * (kHeadDim * 2), tCrB.data(), cache_scale + i * 4, cache_zp + i * 4);
convert_c8_2_half<T, true>(sBdata + i * (kHeadDim * 2) + 1, tCrB.data() + 4, cache_scale + i * 4, cache_zp + i * 4);
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc);
}
}
template<typename CacheKV_traits, typename T, int kHeadDim, int kDataNumPer2Byte, bool A_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename ThrCopy0, typename TiledCopy0>
inline __device__ void gemm_value_quant(
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB,
Tensor4 const& sB, TiledMma tiled_mma,
ThrCopy0 smem_thr_copy_A,
TiledCopy0 smem_tiled_copy_A,
int32_t tidx,
const T * cache_scale, const T * cache_zp) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
if (!A_in_regs) {
copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
}
uint32_t *sBdata = reinterpret_cast<uint32_t *>(sB.data().get()) + tidx * (2 * kDataNumPer2Byte / 4);
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
const int cur_idx = i * kHeadDim * (2 * kDataNumPer2Byte / 4);
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) {
copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
}
}
if constexpr (kDataNumPer2Byte == 4) {
convert_c4_2_half<T, false>(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp);
convert_c4_2_half<T, false>(sBdata + cur_idx + 1, tCrB.data() + 8, cache_scale + 2, cache_zp + 2);
} else {
convert_c8_2_half<T, false>(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp);
convert_c8_2_half<T, false>(sBdata + cur_idx + 1, tCrB.data() + 4, cache_scale + 1, cache_zp + 1);
convert_c8_2_half<T, false>(sBdata + cur_idx + 2, tCrB.data() + 8, cache_scale + 2, cache_zp + 2);
convert_c8_2_half<T, false>(sBdata + cur_idx + 3, tCrB.data() + 12, cache_scale + 3, cache_zp + 3);
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc);
}
}
template<int kMiLen, typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &scores, const uint32_t warp_id, const uint32_t col, const uint32_t reamin_seq_len) {
const int cols = size<1>(scores) / 2;
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
#pragma unroll
for (int ni = 0; ni < cols; ++ni) {
const int col_index = warp_id * 8 + ni * 32 + col * 2;
if (col_index >= reamin_seq_len) {
scores(mi, ni * 2) = -INFINITY;
}
if (col_index + 1 >= reamin_seq_len) {
scores(mi, ni * 2 + 1) = -INFINITY;
}
}
}
}
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
template<int kMiLen, typename Engine0, typename Layout0, typename T>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, T *scores_max){
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
MaxOp<T> max_op;
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ni++) {
scores_max[mi] = max_op(scores_max[mi], tensor(mi, ni));
}
scores_max[mi] = Allreduce<4>::run(scores_max[mi], max_op);
}
}
template <int kMiLen, typename Engine0, typename Layout0, typename T>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, T const *max, T *sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
const float max_scaled = max[mi] * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = expf(tensor(mi, ni) * scale - max_scaled);
sum[mi] += tensor(mi, ni);
}
}
}
template <typename paddle_type>
struct cuteType;
template <>
struct cuteType<phi::dtype::float16> {
using type = cutlass::half_t;
};
template <>
struct cuteType<phi::dtype::bfloat16> {
using type = cutlass::bfloat16_t;
};
template<typename T>
__forceinline__ __device__ auto float_2_half2(const float x) {
if constexpr (std::is_same<T, cutlass::half_t>::value) {
return __float2half2_rn(x);
} else {
return __float2bfloat162_rn(x);
}
}
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
struct uint8 {
uint4 u;
uint4 v;
};
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
inline __device__ Vec() {}
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
inline __device__ void load_from(const void *base_ptr) {
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
}
inline __device__ void store_to(void *base_ptr) {
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
}
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
}
}
inline __device__ void set_zero() {
constexpr int size = sizeof(Vec_type) / sizeof(int);
#pragma unroll
for (int i = 0; i < size; ++i) {
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
}
}
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
}
}
};
template<typename T, int PackSize>
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
static_assert(PackSize % 2 == 0);
#pragma unroll
for (int i = 0; i < PackSize / 2; i++) {
const float cos_inv_freq = cos.data.elt[i];
const float sin_inv_freq = sin.data.elt[i];
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
}
}
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
__forceinline__ __device__ void copy(
TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &identity_MN,
const int max_MN = 0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
}
}
}
}
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename ThrCopy0, typename ThrCopy1,
typename TiledCopy0, typename TiledCopy1>
inline __device__ void gemm(
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
ThrCopy0 &smem_thr_copy_A, ThrCopy1 &smem_thr_copy_B,
TiledCopy0 &smem_tiled_copy_A, TiledCopy1 &smem_tiled_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));
if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template<typename T, typename ReductionOp, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
if (threadIdx.x == 0) { result_broadcast = result; }
__syncthreads();
return result_broadcast;
}
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
}
};
template<typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
ReductionOp op;
#pragma unroll
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <int kPackSize, int knthreads>
__device__ inline int get_data_count(const float * src, const float limit_value) {
int count = 0;
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (src[i] >= limit_value) {
count++;
}
}
count = BlockAllReduce<int, SumOp<int>, knthreads>(count);
return count;
}

View File

@@ -1,802 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/extension.h"
#include "moba_decoder_attn_kernel.h"
#include "moba_attn/moba_attn.h"
template<bool Is_first, int kMiLen, typename Tensor0, typename Tensor1, typename T>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &acc_o, const T *scores_max, const T *scores_max_prev, T * scores_sum, const float softmax_scale) {
if (Is_first) {
scale_apply_exp2<kMiLen>(scores, scores_max, scores_sum, softmax_scale);
} else {
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
const float scores_scale = expf((scores_max_prev[mi] - scores_max[mi]) * softmax_scale);
scores_sum[mi] *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale;
}
}
scale_apply_exp2<kMiLen>(scores, scores_max, scores_sum, softmax_scale);
}
};
template<typename Kernel_traits, typename ParamType>
__global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attention_kernel(ParamType params) {
using cuteType = typename Kernel_traits::cuteType;
using ElementAccum = typename Kernel_traits::ElementAccum;
using CacheKV_traits = typename Kernel_traits::CacheKV_traits;
constexpr int32_t kHeadDim = Kernel_traits::kHeadDim;
constexpr int32_t kHeadDimKV = Kernel_traits::kHeadDimKV;
constexpr int32_t kBlockM = Kernel_traits::kBlockM;
constexpr int32_t kBlockSize = Kernel_traits::kBlockSize;
constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize;
constexpr int32_t kNWarps = Kernel_traits::kNWarps;
constexpr int32_t kTileN = Kernel_traits::kTileN;
constexpr int32_t kBlockN = kTileN * kBlockSize;
constexpr int32_t kDataBits = Kernel_traits::kDataBits;
constexpr int32_t kMiLen = (kGqaGroupSize + 7) / 8;
const int32_t bi = blockIdx.y;
const int32_t tidx = threadIdx.x;
const int32_t partition_idx = blockIdx.x;
const int32_t kv_head_idx = blockIdx.z;
const int32_t q_head_idx = kv_head_idx * kGqaGroupSize;
const int32_t seq_len = params.seq_lens_decoder[bi] == 0 ? 0 : params.seq_lens_decoder[bi] + 1;
const int32_t head_num = params.head_num;
const int32_t kv_head_num = params.kv_head_num;
const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN;
if (seq_len == 0 || partition_idx >= partition_num) {
return;
}
if (seq_len >= params.use_moba_seq_limit && params.qk_gate_topk_idx_ptr[(bi * kv_head_num + kv_head_idx) * Kernel_traits::kMaxN + partition_idx] == 0) {
return;
}
const int q_bias_offset = q_head_idx * kHeadDim;
cuteType * q_input = reinterpret_cast<cuteType *>(params.q_input) + params.cu_seq_q[bi] * head_num * kHeadDim;
Tensor gQ = make_tensor(
make_gmem_ptr(reinterpret_cast<const cuteType *>(q_input) + q_bias_offset),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
const int32_t block_idx = partition_idx * kTileN;
const int* block_table = params.block_table + bi * params.max_num_blocks_per_seq + block_idx;
const int32_t physical_block_number = block_table[0];
const int32_t cache_offset = (physical_block_number * kv_head_num + kv_head_idx) * kBlockSize * kHeadDimKV;
Tensor gK = make_tensor(
make_gmem_ptr(reinterpret_cast<const cuteType *>(params.cache_k) + cache_offset),
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{},
Stride<Int<kHeadDimKV>, _1>{});
Tensor gV = make_tensor(
make_gmem_ptr(reinterpret_cast<const cuteType *>(params.cache_v) + cache_offset),
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{},
Stride<Int<kHeadDimKV>, _1>{});
extern __shared__ char smem_[];
Tensor sQ = make_tensor(
make_smem_ptr(reinterpret_cast<cuteType *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
Tensor sQK = make_tensor(
sQ.data() + size(sQ),
typename Kernel_traits::SmemLayoutQK{});
Tensor sK = make_tensor(sQK.data() + size(sQK), typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
__shared__ ElementAccum scores_warp[kNWarps][kMiLen * kBlockM];
auto gmem_tiled_copy_Q = typename Kernel_traits::GmemTiledCopyQ{};
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
auto gmem_tiled_copy_KV = typename Kernel_traits::GmemTiledCopyKV{};
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_KV.partition_S(gK);
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_KV.partition_S(gV);
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor cKV = make_identity_tensor(make_shape(kBlockSize, kHeadDim));
Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
using SmemCopyAtom = typename Kernel_traits::SmemCopyAtom;
auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
Tensor tSsQK = smem_thr_copy_Q.partition_S(sQK);
Tensor tSrQK = thr_mma.partition_fragment_A(sQK);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
Tensor tSrK = thr_mma.partition_fragment_B(sK);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);
copy<false>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, kGqaGroupSize);
cute::cp_async_fence();
cp_async_wait<0>();
const int32_t remain_seq_len = seq_len - partition_idx * kTileN * kBlockSize;
copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV);
cute::cp_async_fence();
const int32_t warp_id = tidx / 32;
const int32_t lane_id = tidx % 32;
const int32_t row = lane_id / 4;
const int32_t col = lane_id % 4;
const int row_idx = tidx / 4;
using scale_k_vec = Vec<cuteType, 32>;
using scale_v_vec = Vec<cuteType, 4>;
scale_k_vec scale_k;
scale_k_vec zp_k;
scale_v_vec scale_v;
scale_v_vec zp_v;
if constexpr (kDataBits == 4) {
scale_k = *reinterpret_cast<const scale_k_vec*>(params.cache_k_dequant_scale + kv_head_idx * kHeadDim + col * 32);
zp_k = *reinterpret_cast<const scale_k_vec*>(params.cache_k_zp + kv_head_idx * kHeadDim + col * 32);
scale_v = *reinterpret_cast<const scale_v_vec*>(params.cache_v_dequant_scale + kv_head_idx * kHeadDim + row_idx * 4);
zp_v = *reinterpret_cast<const scale_v_vec*>(params.cache_v_zp + kv_head_idx * kHeadDim + row_idx * 4);
}
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
clear(acc_o);
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockSize>>{});
ElementAccum scores_max[kMiLen];
ElementAccum scores_max_prev[kMiLen];
ElementAccum scores_sum[kMiLen];
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
scores_max[mi] = -INFINITY;
scores_sum[mi] = 0;
}
const int cache_offset_step = kv_head_num * kBlockSize * kHeadDimKV;
#pragma unroll
for (int n = 0; n < kTileN; ++n) {
const int cur_remain_seq_len = remain_seq_len - n * kBlockSize;
if (cur_remain_seq_len <= 0) {
break;
}
clear(acc_s);
cp_async_wait<0>();
__syncthreads();
if (n > 0) {
tVgV.data() = tVgV.data() + (block_table[n] - block_table[n - 1]) * cache_offset_step;
}
copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV);
cute::cp_async_fence();
if constexpr (kDataBits == 16) {
if (n == 0) {
gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
} else {
gemm<true>(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
}
} else {
Tensor tSrKQuant = make_tensor<cuteType>(
Layout<
Shape<Shape<_2, _2>, Int<kBlockSize / 32>>,
Stride<Shape<_1, _2>, _4>>{});
if (n == 0) {
gemm_qk_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits>(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt);
} else {
gemm_qk_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits, true>(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt);
}
}
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
if (partition_idx == partition_num - 1 && cur_remain_seq_len < kBlockSize) {
apply_mask<kMiLen>(scores, warp_id, col, cur_remain_seq_len);
}
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
scores_max_prev[mi] = scores_max[mi];
}
reduce_max<kMiLen>(scores, scores_max);
if (col == 0) {
scores_warp[warp_id][row] = scores_max[0];
if constexpr (kMiLen > 1) {
scores_warp[warp_id][row + 8] = scores_max[1];
}
}
__syncthreads();
MaxOp<ElementAccum> max_op;
if (tidx < kGqaGroupSize) {
float cur_max = scores_warp[0][tidx];
#pragma unroll
for (uint32_t i = 1; i < kNWarps; ++i) {
cur_max = max_op(scores_warp[i][tidx], cur_max);
}
scores_warp[0][tidx] = cur_max;
}
cp_async_wait<0>();
__syncthreads();
if (cur_remain_seq_len > kBlockSize && n < kTileN - 1) {
tKgK.data() = tKgK.data() + (block_table[n + 1] - block_table[n]) * cache_offset_step;
copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV);
cute::cp_async_fence();
}
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
scores_max[mi] = scores_warp[0][row + mi * 8];
}
if (n == 0) {
softmax_rescale_o<true, kMiLen>(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh);
} else {
softmax_rescale_o<false, kMiLen>(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh);
}
Tensor rS = convert_type<cuteType>(acc_s);
Tensor trQK = smem_thr_copy_O.retile_S(rS);
Tensor tsQK = smem_thr_copy_O.partition_D(sQK);
cute::copy(smem_tiled_copy_O, trQK, tsQK);
__syncthreads();
if constexpr (kDataBits == 16) {
gemm(acc_o, tSrQK, tOrVt, tSsQK, tOsVt, tiled_mma, smem_thr_copy_Q, smem_thr_copy_V, smem_tiled_copy_Q, smem_tiled_copy_V);
} else {
Tensor tSrVQuant = make_tensor<cuteType>(
Layout<
Shape<_4, Shape<_2, _2>>,
Stride<_1, Shape<_4, _8>>>{});
gemm_value_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits>(acc_o, tSrQK, tSsQK, tSrVQuant, sV, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_v.data.elt, zp_v.data.elt);
}
}
const uint32_t pack_max_partition_num = (params.max_num_partitions + 3) / 4 * 4;
uint32_t max_sum_offset = bi * pack_max_partition_num * head_num + (tidx + q_head_idx) * pack_max_partition_num + partition_idx;
if (tidx < kGqaGroupSize) {
params.maxs[max_sum_offset] = scores_warp[0][tidx] * params.inv_sqrt_dh;
}
SumOp<ElementAccum> sum_op;
#pragma unroll
for (int mi = 0; mi < kMiLen; ++mi) {
scores_sum[mi] = Allreduce<4>::run(scores_sum[mi], sum_op);
}
__syncthreads();
if (col == 0) {
scores_warp[warp_id][row] = scores_sum[0];
if constexpr (kMiLen > 1) {
scores_warp[warp_id][row + 8] = scores_sum[1];
}
}
Tensor rO = convert_type<cuteType>(acc_o);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO);
Tensor taccOsO = smem_thr_copy_O.partition_D(sQ);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
__syncthreads();
if (tidx < kGqaGroupSize) {
float cur_sum = scores_warp[0][tidx];
#pragma unroll
for (uint32_t i = 1; i < kNWarps; ++i) {
cur_sum = sum_op(scores_warp[i][tidx], cur_sum);
}
scores_warp[0][tidx] = cur_sum;
}
Tensor gO = make_tensor(
make_gmem_ptr(reinterpret_cast<cuteType *>(params.partition_attn_out) + ((bi * params.max_num_partitions + partition_idx) * head_num + q_head_idx)* kHeadDim),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
auto gmem_tiled_copy_O = typename Kernel_traits::GmemTiledCopyO{};
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sQ);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
constexpr int32_t copy_size = kGqaGroupSize * 16;
__syncthreads();
if (tidx < copy_size) {
cute::copy(gmem_tiled_copy_O, tOsO(_, 0, _), tOgO(_, 0, _));
}
if constexpr (kMiLen > 1) {
if (tidx < copy_size - 128) {
cute::copy(gmem_tiled_copy_O, tOsO(_, 1, _), tOgO(_, 1, _));
}
}
if (tidx < kGqaGroupSize) {
params.sums[max_sum_offset] = scores_warp[0][tidx];
}
}
template<typename Kernel_traits, typename ParamType>
inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType &params, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads;
const int32_t bi = blockIdx.z;
const int32_t tidx = threadIdx.x;
const int32_t head_idx = blockIdx.y;
const int32_t head_num = params.head_num;
using float_vec = Vec<float, kNFloatPacksize>;
const int32_t offset = bi * head_num * pack_max_partition_num + head_idx * pack_max_partition_num;
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = params.maxs + offset;
float global_max_logit = -FLT_MAX;
int32_t idx = tidx * kNFloatPacksize;
#pragma unroll
for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) {
float_vec cur_max = *reinterpret_cast<const float_vec*>(max_logits_ptr + idx);
#pragma unroll
for (int32_t j = 0; j < kNFloatPacksize; ++j) {
if (seq_len >= params.use_moba_seq_limit) {
if (qk_gate_topk_idx_ptr[idx + j] != 0) {
global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]);
}
} else {
global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]);
}
}
cur_max.store_to(shared_max_logits + idx);
}
const int32_t packed_data_num = partition_num / kNFloatPacksize * kNFloatPacksize;
idx = packed_data_num + tidx;
#pragma unroll
for (; idx < partition_num; idx += kNReduceThreads) {
if (seq_len >= params.use_moba_seq_limit) {
if (qk_gate_topk_idx_ptr[idx] != 0) {
float cur_max = max_logits_ptr[idx];
global_max_logit = fmaxf(global_max_logit, cur_max);
shared_max_logits[idx] = cur_max;
}
} else {
float cur_max = max_logits_ptr[idx];
global_max_logit = fmaxf(global_max_logit, cur_max);
shared_max_logits[idx] = cur_max;
}
}
__syncthreads();
global_max_logit = BlockAllReduce<float, MaxOp<float>, kNReduceThreads>(global_max_logit);
float* share_sum_scale = reinterpret_cast<float*>(shared_mem + sizeof(float) * pack_max_partition_num);
const float* exp_sums_ptr = params.sums + offset;
float global_exp_sum = 0.0f;
idx = tidx * kNFloatPacksize;
#pragma unroll
for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) {
float_vec share_max = *reinterpret_cast<const float_vec*>(shared_max_logits + idx);
#pragma unroll
for (int32_t j = 0; j < kNFloatPacksize; ++j) {
if (seq_len >= params.use_moba_seq_limit) {
if (qk_gate_topk_idx_ptr[idx + j] != 0) {
float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit);
float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max;
global_exp_sum += rescaled_exp_sum;
share_max.data.elt[j] = exp_sub_max;
}
} else {
float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit);
float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max;
global_exp_sum += rescaled_exp_sum;
share_max.data.elt[j] = exp_sub_max;
}
}
share_max.store_to(share_sum_scale + idx);
}
idx = packed_data_num + tidx;
#pragma unroll
for (; idx < partition_num; idx += kNReduceThreads) {
if (seq_len >= params.use_moba_seq_limit) {
if (qk_gate_topk_idx_ptr[idx] != 0) {
float share_max = shared_max_logits[idx];
float exp_sub_max = expf(share_max - global_max_logit);
float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max;
global_exp_sum += rescaled_exp_sum;
share_sum_scale[idx] = exp_sub_max;
}
} else {
float share_max = shared_max_logits[idx];
float exp_sub_max = expf(share_max - global_max_logit);
float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max;
global_exp_sum += rescaled_exp_sum;
share_sum_scale[idx] = exp_sub_max;
}
}
__syncthreads();
global_exp_sum = BlockAllReduce<float, SumOp<float>, kNReduceThreads>(global_exp_sum);
const float inv_global_exp_sum = fdividef(1.0f, global_exp_sum + 1e-6f);
return inv_global_exp_sum;
}
template<typename Kernel_traits, typename ParamType>
__global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_attention_merge_kernel(ParamType params) {
using cuteType = typename Kernel_traits::cuteType;
constexpr int32_t kBlockN = Kernel_traits::kTileN * Kernel_traits::kBlockSize;
constexpr int32_t kNReducePacksize = 16 / sizeof(cuteType);
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
constexpr int32_t kNReduceWarps = Kernel_traits::kNReduceWarps;
constexpr int32_t kHeadDim = Kernel_traits::kHeadDim;
const int32_t bi = blockIdx.z;
const int32_t headdim_idx = kNReducePacksize * kNReduceWarps * blockIdx.x;
const int32_t tidx = threadIdx.x;
const int32_t head_idx = blockIdx.y;
const int32_t warp_id = tidx / 32;
const int32_t lane_id = tidx % 32;
const int32_t seq_len = params.seq_lens_decoder[bi] + 1;
const int32_t head_num = params.head_num;
using pack_half = typename PackedHalf<cuteType>::Type;
if (params.seq_lens_decoder[bi] == 0) {
return;
}
extern __shared__ char shared_mem[];
const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN;
const int32_t pack_max_partition_num = (params.max_num_partitions + kNFloatPacksize - 1) / kNFloatPacksize * kNFloatPacksize;
float* share_sum_scale = reinterpret_cast<float*>(shared_mem + sizeof(float) * pack_max_partition_num);
constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize;
const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize;
const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN;
float inv_global_exp_sum = caluate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
using T_vec = Vec<cuteType, kNReducePacksize>;
cuteType* partition_attn_out = reinterpret_cast<cuteType*>(params.partition_attn_out) + bi * head_num * params.max_num_partitions * kHeadDim + head_idx * kHeadDim + headdim_idx;
Vec<float, kNReducePacksize> acc;
acc.set_zero();
#pragma unroll
for (int idx = lane_id; idx < partition_num; idx += 32) {
if (seq_len >= params.use_moba_seq_limit && qk_gate_topk_idx_ptr[idx] == 0) {
continue;
}
T_vec sub_logits = *reinterpret_cast<T_vec*>(&partition_attn_out[idx * head_num * kHeadDim + warp_id * kNReducePacksize]);
float scale = share_sum_scale[idx];
#pragma unroll
for (int k = 0; k < kNReducePacksize; ++k) {
acc.data.elt[k] += static_cast<float>(sub_logits.data.elt[k]) * scale;
}
}
__syncthreads();
T_vec out;
#pragma unroll
for (int k = 0; k < kNReducePacksize; ++k) {
out.data.elt[k] = static_cast<cuteType>(WarpAllReduce<float, SumOp<float>>(acc.data.elt[k]) * inv_global_exp_sum);
}
const int ori_token_idx = params.cu_seq_q[bi];
cuteType * attn_out = reinterpret_cast<cuteType *>(params.attn_out) + ori_token_idx * head_num * kHeadDim + head_idx * kHeadDim + headdim_idx + warp_id * kNReducePacksize;
if (lane_id == 0) {
out.store_to(attn_out);
}
}
template<typename Kernel_traits, typename ParamType>
void run_moba_decoder_attn(ParamType &params, cudaStream_t stream) {
dim3 grid;
grid.x = params.max_num_partitions;
grid.y = params.batch_size;
grid.z = params.kv_head_num;
constexpr int smem_size = Kernel_traits::kShareMemSize;
constexpr auto kernel = &moba_decoder_attention_kernel<Kernel_traits, ParamType>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
int32_t reduce_shared_mem_size = 2 * (params.max_num_partitions + 4) * sizeof(float);
constexpr int32_t pack_size = 16 / sizeof(typename Kernel_traits::cuteType);
static_assert(Kernel_traits::kHeadDim % pack_size == 0);
static_assert((Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps) % pack_size == 0);
grid.x = Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps / pack_size;
grid.y = params.head_num;
grid.z = params.batch_size;
auto reduce_kernel = &moba_decoder_attention_merge_kernel<Kernel_traits, ParamType>;
if (reduce_shared_mem_size >= 48 * 1024) {
cudaFuncSetAttribute(
reduce_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, reduce_shared_mem_size);
}
reduce_kernel<<<grid, Kernel_traits::kNReduceThreads, reduce_shared_mem_size, stream>>>(params);
}
template<typename cute_type, int kCacheBits, int kBlockN, int kMaxN, typename ParamType>
void run_moba_decoder_attn_hdim128(ParamType &params, cudaStream_t stream) {
const int gqaGroupSize = params.head_num / params.kv_head_num;
using CacheKVTraits = CacheKV_quant_traits<cute_type, kCacheBits>;
constexpr int kTileN = kBlockN / CacheKVTraits::kBlockSize;
switch (gqaGroupSize) {
case 12: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<12, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 8: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<8, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 7: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<7, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 6: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<6, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 5: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<5, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
case 4: {
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<4, kTileN, kMaxN,CacheKVTraits>>(params, stream);
break;
}
default: {
PADDLE_THROW(phi::errors::Unimplemented(
"DecoderBlockAttention not implemented for gqaGroupSize = %d", gqaGroupSize));
}
}
}
template <typename T>
void DispatchMobaDecoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& k_block_means,
const paddle::Tensor& out,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_q,
const int max_seq_k,
const int batch_size,
const int max_input_length,
const int use_moba_seq_limit,
const std::string &cache_quant_type_str) {
using cute_type = typename cuteType<T>::type;
const int kMobaBlockSize = 128;
const int kMaxN = 1024;
constexpr int max_seq_per_block = kMobaBlockSize;
moba_decoder_attn_params<cute_type> params;
memset(&params, 0, sizeof(params));
const uint32_t max_num_partitions = (max_seq_k + max_seq_per_block) / max_seq_per_block;
assert(head_dim == 128);
paddle::Tensor maxs = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place());
paddle::Tensor sums = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place());
paddle::Tensor partition_attn_out = paddle::empty({batch_size, max_num_partitions, head_num, head_dim}, q_input.dtype(), q_input.place());
params.q_input = reinterpret_cast<cute_type *>(const_cast<T*>(q_input.data<T>()));
params.attn_out = reinterpret_cast<cute_type *>(const_cast<T*>(out.data<T>()));
params.seq_lens_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_len_decoder.data<int>());
params.block_table = const_cast<int*>(block_tables.data<int>());
params.max_input_length = max_input_length;
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_num_blocks_per_seq = block_tables.dims()[1];
params.batch_size = batch_size;
params.inv_sqrt_dh = 1.0f / std::sqrt(head_dim);
params.max_num_partitions = max_num_partitions;
params.maxs = reinterpret_cast<float*>(maxs.data<float>());
params.sums = reinterpret_cast<float*>(sums.data<float>());
params.partition_attn_out = reinterpret_cast<cute_type *>(partition_attn_out.data<T>());
params.qk_gate_topk_idx_ptr = const_cast<int*>(qk_gate_topk_idx.data<int>());
params.use_moba_seq_limit = use_moba_seq_limit;
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
if (cache_quant_type_str == "none") {
params.cache_k = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k.data<T>()));
params.cache_v = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v.data<T>()));
run_moba_decoder_attn_hdim128<cute_type, 16, max_seq_per_block, kMaxN>(params, q_input.stream());
} else {
params.cache_k = const_cast<uint8_t*>(cache_k.data<uint8_t>());
params.cache_v = const_cast<uint8_t*>(cache_v.data<uint8_t>());
params.cache_k_quant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_quant_scale.get().data<T>()));
params.cache_v_quant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_quant_scale.get().data<T>()));
params.cache_k_dequant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_dequant_scale.get().data<T>()));
params.cache_v_dequant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_dequant_scale.get().data<T>()));
params.cache_k_zp = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_zero_points.get().data<T>()));
params.cache_v_zp = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_zero_points.get().data<T>()));
if (cache_quant_type_str == "cache_int8_zp") {
run_moba_decoder_attn_hdim128<cute_type, 8, max_seq_per_block, kMaxN>(params, q_input.stream());
} else if (cache_quant_type_str == "cache_int4_zp") {
run_moba_decoder_attn_hdim128<cute_type, 4, max_seq_per_block, kMaxN>(params, q_input.stream());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"GQA Attention not implemented for cache_quant_type_str = %s", cache_quant_type_str.c_str()));
}
}
}
void MobaDecoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& k_block_means,
const paddle::Tensor& out,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const std::string &cache_quant_type_str) {
const int batch_size = block_tables.dims()[0];
if (q_input.dtype() == paddle::DataType::FLOAT16) {
return DispatchMobaDecoderAttn<phi::dtype::float16>(
q_input,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cache_k,
cache_v,
block_tables,
k_block_means,
out,
qk_gate_topk_idx,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_q,
max_seq_k,
batch_size,
max_input_length,
use_moba_seq_limit,
cache_quant_type_str);
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
return DispatchMobaDecoderAttn<phi::dtype::bfloat16>(
q_input,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cache_k,
cache_v,
block_tables,
k_block_means,
out,
qk_gate_topk_idx,
cache_k_quant_scale,
cache_v_quant_scale,
cache_k_dequant_scale,
cache_v_dequant_scale,
cache_k_zero_points,
cache_v_zero_points,
head_num,
kv_head_num,
head_dim,
max_seq_q,
max_seq_k,
batch_size,
max_input_length,
use_moba_seq_limit,
cache_quant_type_str);
}
}

View File

@@ -1,225 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/extension.h"
#include "cute/tensor.hpp"
#include "cute/algorithm/copy.hpp"
#include "cute/algorithm/gemm.hpp"
#include "../moba_attn_utils.hpp"
using namespace cute;
template <typename T>
struct moba_decoder_attn_params {
T *__restrict__ q_input;
void *__restrict__ cache_k;
void *__restrict__ cache_v;
T *__restrict__ attn_out;
T *__restrict__ partition_attn_out;
T *__restrict__ cache_k_dequant_scale;
T *__restrict__ cache_v_dequant_scale;
T *__restrict__ cache_k_quant_scale;
T *__restrict__ cache_v_quant_scale;
T *__restrict__ cache_k_zp;
T *__restrict__ cache_v_zp;
int * __restrict__ cu_seq_q;
float * sums;
float * maxs;
int * seq_lens_encoder;
int * seq_lens_decoder;
int * block_table;
int max_input_length;
int max_seq_len;
int head_num;
int kv_head_num;
int max_num_blocks_per_seq;
float scale_softmax;
int batch_size;
int max_num_partitions;
float inv_sqrt_dh;
int *qk_gate_topk_idx_ptr;
int use_moba_seq_limit;
};
template <typename cute_type_, int DataBits_>
struct CacheKV_quant_traits {
using cuteType = cute_type_;
static constexpr int kDataBits = DataBits_;
static constexpr int kBlockSize = 64;
static constexpr int kHeadDim = 128;
static constexpr int kBlockKSmem = 64;
using SmemLayoutAtomQ = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<
Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockSize>, Int<kHeadDim>>{}));
static constexpr int kNWarps = 4;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kThreadPerValue = 16 / sizeof(cuteType);
static constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue;
using GmemLayoutAtom = Layout<
Shape <Int<kNThreads / kThreadsPerRow>, Int<kThreadsPerRow>>,
Stride<Int<kThreadsPerRow>, _1>>;
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cuteType>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<cuteType, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = Layout<Shape<_1,_4,_1>>;
using PermutationMNK = Tile<_16, Int<16 * kNWarps>, _16>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
ValLayoutMNK,
PermutationMNK>;
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, cuteType>;
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<Shape<Int<kBlockKSmem>, Int<kBlockSize>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockSize>>{}));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, cuteType>;
static constexpr int kShareMemSize = size(SmemLayoutKV{}) * 2 * sizeof(cuteType);
};
template <int kGqaGroupSize_, int kTileN_, int kMaxN_, typename CacheKV_traits_>
struct moba_decoder_attn_kernel_traits {
using ElementAccum = float;
using CacheKV_traits = CacheKV_traits_;
using cuteType = typename CacheKV_traits::cuteType;
static constexpr int kDataBits = CacheKV_traits::kDataBits;
static constexpr int kTileN = kTileN_;
static constexpr int kMaxN = kMaxN_;
static constexpr int kGqaGroupSize = kGqaGroupSize_;
static constexpr int kHeadDim = CacheKV_traits::kHeadDim;
static constexpr int kHeadDimKV = kHeadDim / (16 / kDataBits);
static constexpr int kMinGemmM = 16;
static constexpr int kBlockM = (kGqaGroupSize + kMinGemmM - 1) / kMinGemmM * kMinGemmM;
static constexpr int kBlockSize = CacheKV_traits::kBlockSize;
static_assert(kGqaGroupSize <= 16);
static constexpr int32_t kNWarps = CacheKV_traits::kNWarps;
static constexpr int kBlockKSmem = CacheKV_traits::kBlockKSmem;
static constexpr int kBlockKVSmem = kHeadDimKV <= 64 ? kHeadDimKV : 64;
static_assert(kHeadDim % kBlockKSmem == 0);
static constexpr int kNReduceWarps = 4;
static constexpr int kNReduceThreads = kNReduceWarps * 32;
using SmemLayoutAtomQ = typename CacheKV_traits::SmemLayoutAtomQ;
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutQK = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kBlockSize>>{}));
using SmemLayoutAtomKV = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<
Shape<Int<8>, Int<kBlockKVSmem>>,
Stride<Int<kBlockKVSmem>, _1>>{}));
using SmemLayoutKV_ = decltype(tile_to_shape(
SmemLayoutAtomKV{},
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{}));
using SmemLayoutKV = std::conditional_t<
kDataBits == 16,
SmemLayoutKV_,
decltype(get_nonswizzle_portion(SmemLayoutKV_{}))
>;
constexpr static int kBlockKVSize = kDataBits == 4 ? 32 : kBlockSize;
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<Shape<Int<kBlockKSmem>, Int<kBlockKVSize>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockKVSize>>{}));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
static constexpr int kThreadsPerRow = CacheKV_traits::kThreadsPerRow;
static constexpr int kThreadsKVPerRow = kThreadsPerRow / (16 / kDataBits);
static constexpr int kNThreads = CacheKV_traits::kNThreads;
using GmemKVLayoutAtom = Layout<
Shape<Int<kNThreads / kThreadsKVPerRow>, Int<kThreadsKVPerRow>>,
Stride<Int<kThreadsKVPerRow>, _1>>;
using SmemCopyAtom = typename CacheKV_traits::SmemCopyAtom;
using TiledMma = typename CacheKV_traits::TiledMma;
static constexpr int kThreadPerValue = CacheKV_traits::kThreadPerValue;
using GmemTiledCopyQ = typename CacheKV_traits::GmemTiledCopyQ;
using GmemLayoutAtom = typename CacheKV_traits::GmemLayoutAtom;
using GmemTiledCopyKV = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cuteType>{},
GmemKVLayoutAtom{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
using SmemCopyAtomTransposed = typename CacheKV_traits::SmemCopyAtomTransposed;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, cuteType>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, cuteType>;
using SmemLayoutAtomO = decltype(
composition(Swizzle<3, 3, 3>{},
Layout<
Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
static constexpr int kShareMemSize = (size(SmemLayoutQ{}) + size(SmemLayoutQK{}) + size(SmemLayoutKV{}) * 2) * sizeof(cuteType);
};

View File

@@ -1,189 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "../moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename T, int kBlockSize, int kHeadDim, int moba_block_size, int kMaxN>
__global__ void moba_decoder_attn_write_c16(
const T * qkv_out,
const T * qkv_bias,
T * q_input,
const int * cu_seq_q,
const int * cu_seq_k,
const int * seq_len_encoder,
const int * seq_len_decoder,
T * cache_k,
T * cache_v,
const int * block_tables,
const float * rope_sin_cos,
T *k_block_means,
const int head_num,
const int kv_head_num,
const int max_blocks_per_seq,
const int max_input_length) {
int bidh = blockIdx.x;
const int bidb = blockIdx.y;
const int tidx = threadIdx.x;
const int seq_len = seq_len_decoder[bidb];
if (seq_len == 0) {
return;
}
constexpr int kPackSize = 4;
using SrcType = Vec<T, kPackSize>;
using rope_type = Vec<float, kPackSize / 2>;
SrcType src, bias, k_prev;
rope_type sin, cos;
const int bias_idx = bidh * kHeadDim + tidx * kPackSize;
const int ori_token_idx = cu_seq_q[bidb];
src.load_from(qkv_out + ori_token_idx * (head_num + 2 * kv_head_num) * kHeadDim + bias_idx);
if (qkv_bias != nullptr) {
bias.load_from(qkv_bias + bias_idx);
src.add(bias);
}
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
const int32_t physical_block_number = block_table_now[seq_len / kBlockSize];
if (bidh < head_num) {
const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<T, kPackSize>(src, cos, sin);
src.store_to(q_input + cu_seq_q[bidb] * head_num * kHeadDim + bias_idx);
} else if (bidh < head_num + kv_head_num) {
bidh -= head_num;
const int token_in_blocks = seq_len % kBlockSize;
const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<T, kPackSize>(src, cos, sin);
T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim;
src.store_to(cache);
const int seq_len_block = seq_len / moba_block_size;
const int store_mean_idx = (bidb * kMaxN + seq_len_block) * kv_head_num * kHeadDim + bidh * kHeadDim + tidx * kPackSize;
if (seq_len % moba_block_size != 0) {
const int token_num_prev = seq_len % moba_block_size;
const float inv_tokens_sum = fdividef(1.0f, token_num_prev + 1);
k_prev.load_from(k_block_means + store_mean_idx);
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
src.data.elt[i] = T(inv_tokens_sum * (float(src.data.elt[i]) + float(k_prev.data.elt[i]) * token_num_prev));
}
}
src.store_to(k_block_means + store_mean_idx);
} else {
bidh -= head_num + kv_head_num;
const int token_in_blocks = seq_len % kBlockSize;
T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim;
src.store_to(cache);
}
}
void MobaDecoderAttnWriteCacheKv(
const paddle::Tensor& qkv_out,
const paddle::Tensor& q_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::Tensor& rope_sin_cos,
const paddle::Tensor& k_block_means,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const std::string &cache_quant_type_str) {
constexpr int kThreads = 32;
constexpr int kHeadDim = 128;
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
assert(kHeadDim == head_dim);
constexpr int kBlockSize = 64;
const int max_blocks_per_seq = block_tables.dims()[1];
const int batch_size = block_tables.dims()[0];
if (cache_quant_type_str == "none") {
dim3 grid_dims;
grid_dims.x = head_num + kv_head_num * 2;
grid_dims.y = batch_size;
if (qkv_out.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
moba_decoder_attn_write_c16<T, kBlockSize, kHeadDim, kMobaBlockSize, kMaxN><<<grid_dims, kThreads, 0, qkv_out.stream()>>>(
qkv_out.data<T>(),
qkv_bias ? qkv_bias.get().data<T>() : nullptr,
const_cast<T*>(q_input.data<T>()),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
const_cast<T *>(cache_k.data<T>()),
const_cast<T *>(cache_v.data<T>()),
block_tables.data<int>(),
rope_sin_cos.data<float>(),
const_cast<T*>(k_block_means.data<T>()),
head_num,
kv_head_num,
max_blocks_per_seq,
max_input_length);
} else if (qkv_out.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
moba_decoder_attn_write_c16<T, kBlockSize, kHeadDim, kMobaBlockSize, kMaxN><<<grid_dims, kThreads, 0, qkv_out.stream()>>>(
qkv_out.data<T>(),
qkv_bias ? qkv_bias.get().data<T>() : nullptr,
const_cast<T*>(q_input.data<T>()),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
const_cast<T *>(cache_k.data<T>()),
const_cast<T *>(cache_v.data<T>()),
block_tables.data<int>(),
rope_sin_cos.data<float>(),
const_cast<T*>(k_block_means.data<T>()),
head_num,
kv_head_num,
max_blocks_per_seq,
max_input_length);
}
} else {
PD_THROW("Only supported cache_quant_type_str in ['none'].");
}
}

View File

@@ -1,236 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename T, int knthreads, int moba_block_size, int kBlockMaxN, int searchtimes>
__global__ void qk_gate_sort_decoder_kernel(
const T* qk_gate_weight,
int * qk_gate_topk_idx,
const int *decoder_seq_lens,
const int head_num,
const int kv_head_num,
const int kGqaGroupSize,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
const int bidb = blockIdx.x;
const int bidh = blockIdx.y;
const int tidx = threadIdx.x;
const int bidh_kv = bidh / kGqaGroupSize;
if (decoder_seq_lens[bidb] == 0 || decoder_seq_lens[bidb] < use_moba_seq_limit) {
return;
}
const int seq_len = (decoder_seq_lens[bidb] + moba_block_size - 1) / moba_block_size;
constexpr int kPackSize = kBlockMaxN / knthreads;
static_assert(kBlockMaxN % knthreads == 0);
T token_mean[kPackSize];
using SrcType = Vec<T, kPackSize>;
using SrcType_f = Vec<float, kPackSize>;
using SrcType_i = Vec<int, kPackSize>;
SrcType src;
SrcType_f src_f;
SrcType_i select_idx;
select_idx.set_zero();
const int load_offset = bidb * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize;
src.load_from(qk_gate_weight + load_offset);
float max_global = -FLT_MAX;
float min_global = FLT_MAX;
const int data_len = seq_len - tidx * kPackSize;
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (i < data_len) {
src_f.data.elt[i] = float(src.data.elt[i]);
min_global = min(min_global, src_f.data.elt[i]);
} else {
src_f.data.elt[i] = -FLT_MAX;
}
max_global = max(max_global, src_f.data.elt[i]);
}
max_global = BlockAllReduce<float, MaxOp<float>, knthreads>(max_global);
min_global = BlockAllReduce<float, MinOp<float>, knthreads>(min_global);
float right_limit = max_global;
float left_limit = min_global;
float mid_limit;
int count;
#pragma unroll
for (int i = 0; i < searchtimes; i++) {
mid_limit = (left_limit + right_limit) * 0.5f;
count = get_data_count<kPackSize, knthreads>(src_f.data.elt, mid_limit);
if (count < top_k_left) {
right_limit = mid_limit;
} else if (count > top_k_right) {
left_limit = mid_limit;
} else {
break;
}
}
const int store_idx = bidb * kv_head_num * kBlockMaxN + bidh_kv * kBlockMaxN + tidx * kPackSize;
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (src_f.data.elt[i] >= mid_limit) {
qk_gate_topk_idx[store_idx + i] = 1;
}
}
if (tidx == 0) {
qk_gate_topk_idx[store_idx] = 1;
qk_gate_topk_idx[store_idx + seq_len - 1] = 1;
qk_gate_topk_idx[store_idx + seq_len - 2] = 1;
}
}
template <int kBlockMaxN, int moba_block_size, typename T>
void qk_gate_sort_decoder(
const T* qk_gate_weight,
int * qk_gate_topk_idx,
const int *decoder_seq_lens,
const int head_num,
const int kv_head_num,
const int batch_size,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit,
cudaStream_t stream) {
const int gqa_group_size = head_num / kv_head_num;
constexpr int kPackSize = 16 / sizeof(T);
const int knthreads = kBlockMaxN / kPackSize;
dim3 grid_dims;
grid_dims.x = batch_size;
grid_dims.y = head_num;
const int searchtimes = 6;
constexpr auto kernel = qk_gate_sort_decoder_kernel<T, knthreads, moba_block_size, kBlockMaxN, searchtimes>;
kernel<<<grid_dims, knthreads, 0, 0>>>(
qk_gate_weight,
qk_gate_topk_idx,
decoder_seq_lens,
head_num,
kv_head_num,
gqa_group_size,
top_k_left,
top_k_right,
use_moba_seq_limit);
}
template <typename T>
std::vector<paddle::Tensor> DispatchQkSortDecoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
const int batch_size = seq_len_decoder.dims()[0];
paddle::Tensor qk_gate_topk_idx = paddle::empty({batch_size, kv_head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place());
qk_gate_sort_decoder<kMaxN, kMobaBlockSize, T>(
qk_gate_weight.data<T>(),
qk_gate_topk_idx.data<int>(),
seq_len_decoder.data<int>(),
head_num,
kv_head_num,
batch_size,
top_k_left,
top_k_right,
use_moba_seq_limit,
qk_gate_weight.stream()
);
return {qk_gate_topk_idx};
}
std::vector<paddle::Tensor> QkSortDecoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) {
return std::move(
DispatchQkSortDecoder<phi::dtype::float16>(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
head_num,
kv_head_num,
top_k_left,
top_k_right,
use_moba_seq_limit)
);
} else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) {
return std::move(
DispatchQkSortDecoder<phi::dtype::bfloat16>(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
head_num,
kv_head_num,
top_k_left,
top_k_right,
use_moba_seq_limit)
);
}
}
PD_BUILD_OP(moba_qk_sort_decoder)
.Inputs({
"qk_gate_weight",
"seq_len_encoder",
"seq_len_decoder"})
.Attrs({
"head_num: int",
"kv_head_num: int",
"top_k_left: int",
"top_k_right: int",
"use_moba_seq_limit: int"})
.Outputs({"qk_gate_topk_idx"})
.SetKernelFn(PD_KERNEL(QkSortDecoder));

View File

@@ -1,143 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
using namespace cute;
struct moba_encoder_attn_params {
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void * __restrict__ o_ptr;
int * __restrict__ cu_seq_q;
int * __restrict__ cu_seq_k;
int * __restrict__ qk_gate_topk_idx;
int * __restrict__ seq_len_encoder;
int * __restrict__ cu_seq_q_pack;
int head_num;
int kv_head_num;
int max_seq_q;
int max_seq_k;
int batch_size;
int gqa_group_size;
float scale_softmax_log2;
int use_moba_seq_limit;
};
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
};
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, int kMaxN_, bool UseMoba_, typename elem_type=cutlass::half_t>
struct moba_encoder_attn_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using index_t = int32_t;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int UseMoba = UseMoba_;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static constexpr int kMaxN = kMaxN_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
static constexpr int kStages = kStages_;
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma0 = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma1 = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutMNK{}));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyOValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using GmemTiledCopyO = decltype(make_tiled_copy(
TiledCopyOAtom{},
TiledCopyOThrLayout{}, // Thr layout
TiledCopyOValLayout{} // Val layout
));
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
};

View File

@@ -1,473 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
using namespace cute;
enum class AttnNamedBarriers {
QueryEmpty = 0,
ValueEmpty = 1,
TileCountSmemEmpty = 2,
TileCountSmemFull = 3,
WarpSchedulerWG1 = 4,
WarpSchedulerWG2 = 5,
WarpSchedulerWG3 = 6,
};
template <typename Ktraits>
struct CollectiveMainloopAttn {
using Element = typename Ktraits::Element;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutV = SmemLayoutK;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt =
decltype(cute::composition(SmemLayoutV{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
using SmemLayoutO = typename Ktraits::SmemLayoutO;
using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;
using TMA_Q = decltype(make_tma_copy(
GmemTiledCopyQ{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for Q
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
// Host side kernel arguments
struct Arguments {
Element const* ptr_Q;
LayoutT layout_Q;
Element const* ptr_K;
LayoutT layout_K;
Element const* ptr_V;
LayoutT layout_V;
float const softmax_scale_log2;
};
// Device side kernel params
struct Params {
LayoutT layout_Q;
LayoutT layout_K;
LayoutT layout_V;
cutlass::FastDivmod qhead_per_khead_divmod;
TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V;
float const softmax_scale_log2;
};
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
TMA_Q tma_load_Q = make_tma_copy(
GmemTiledCopyQ{},
mQ,
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for Q
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
TMA_KV tma_load_K = make_tma_copy(
GmemTiledCopyKV{},
mK,
SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
TMA_KV tma_load_V = make_tma_copy(
GmemTiledCopyKV{},
mV,
SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {args.layout_Q, args.layout_K, args.layout_V,
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
tma_load_Q, tma_load_K, tma_load_V,
args.softmax_scale_log2};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(
const MTensor &m_tensor,
const Shape &tile_shape,
const int *cu_seq_len,
const int bidh,
const int bidb,
const int actual_seq_len) const {
auto g_offset = local_tile(
m_tensor(_, _, bidh),
cute::make_shape(1, get<1>(tile_shape)),
make_coord(cu_seq_len[bidb], _0{}));
auto g_sequence = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
g_offset.stride()
));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <bool UseMoba, typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k,
PipelineState& smem_pipe_write_v,
SharedStorage &shared_storage,
const int *qk_gate_topk_idx,
const int n_block_max,
const int m_block,
const int bidh,
const int bidb,
const int *cu_seq_q,
const int *cu_seq_k,
const int seq_len_q,
const int seq_len_k) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
Tensor gQ = get_local_tile_tensor(
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
Tensor gK = get_local_tile_tensor(
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor gV = get_local_tile_tensor(
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK));
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV));
uint16_t mcast_mask_kv = 0;
int n_block = n_block_max - 1;
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
}
if (lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
}
if (lane_predicate) {
int idx = 0;
#pragma unroll 2
for (; n_block > 0; ) {
pipeline_k.producer_acquire(smem_pipe_write_k);
int pre_idx = 1;
if constexpr (UseMoba) {
pre_idx = qk_gate_topk_idx[idx];
}
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block - pre_idx), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
n_block -= pre_idx;
idx += 1;
}
}
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
CUTLASS_DEVICE void
warp_scheduler_barrier_sync() {
if constexpr (UseSchedulerBarrier) {
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
}
}
CUTLASS_DEVICE void
mma_init() {
if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if (cutlass::canonical_warp_group_idx() > 1) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
}
if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
if (cutlass::canonical_warp_group_idx() > 2) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
}
}
}
CUTLASS_DEVICE void
warp_scheduler_barrier_arrive() {
if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
} else {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
}
}
template <bool UseMoba, typename SharedStorage, typename FrgTensorO, typename Softmax>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_read_k,
PipelineState& smem_pipe_read_v,
FrgTensorO& tOrO,
Softmax& softmax,
const int *qk_gate_topk_idx,
const int n_block_max,
const int thread_idx,
const int m_block,
const int seq_len_q,
const int seq_len_k,
SharedStorage& shared_storage) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
typename Ktraits::TiledMma0 tiled_mma0;
typename Ktraits::TiledMma1 tiled_mma1;
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
Tensor tSrK = threadMma0.partition_fragment_B(sK);
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
int n_block = n_block_max - 1;
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
warp_scheduler_barrier_sync();
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
warp_scheduler_barrier_arrive();
warpgroup_wait<0>();
pipeline_k.consumer_release(smem_pipe_read_k);
++smem_pipe_read_k;
auto col_limit_causal = [&](int row, int n_block) {
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
};
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<1>(tScS(i))) >=
std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
tSrS(i) = -INFINITY;
}
}
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);
int idx = 0;
#pragma unroll 2
for (; n_block > 0; ) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
warp_scheduler_barrier_sync();
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
warp_scheduler_barrier_arrive();
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V
++smem_pipe_read_k;
++smem_pipe_read_v;
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
if constexpr (UseMoba) {
n_block -= qk_gate_topk_idx[idx];
idx += 1;
} else {
n_block -= 1;
}
}
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v);
++smem_pipe_read_v;
softmax.rescale_o(tOrO, scores_scale);
}
template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO const& tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
const int o_head_stride,
const int real_seq,
T * out_ptr) {
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = convert_type<Element>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(AttnNamedBarriers::ValueEmpty) /*id*/);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(o_head_stride, _1{}));
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
if (real_seq >= kBlockM) {
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
} else {
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
}
}
};

View File

@@ -1,384 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
# include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
#include "kernel_traits.h"
#include "mainloop_attn.hpp"
#include "softmax.hpp"
#include "cutlass/arch/reg_reconfig.h"
template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
return make_layout(
make_shape(token_num, kHeadDim, head_num),
make_stride(head_num * kHeadDim, _1{}, kHeadDim));
}
template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
moba_encoder_attention_kernel(
CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
CUTE_GRID_CONSTANT moba_encoder_attn_params const data_params) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
constexpr int kHeadDim = Ktraits::kHeadDim;
constexpr int kMaxN = Ktraits::kMaxN;
using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int seq_len_q = data_params.seq_len_encoder[bidb];
const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];
if (seq_len_q == 0) {
return;
}
__align__(16) __shared__ int qk_gate_topk_idx[kMaxN];
const int *qk_gate_idx_cur_offset = data_params.qk_gate_topk_idx + data_params.cu_seq_q_pack[bidb] / kBlockM * data_params.head_num * kMaxN + (m_block * data_params.head_num + bidh) * kMaxN;
#pragma unroll
for (int i = threadIdx.x; i < kMaxN / 4; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
reinterpret_cast<int4*>(qk_gate_topk_idx)[i] = reinterpret_cast<const int4*>(qk_gate_idx_cur_offset)[i];
}
const int n_block_max = min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN), cute::ceil_div(seq_len_k, kBlockN));
if (m_block * kBlockM >= seq_len_q) {
return;
}
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1);
}
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
__syncthreads();
CollectiveMainloop collective_mainloop;
if (warp_group_idx == 0) {
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) {
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load<Ktraits::UseMoba>(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_write_k,
smem_pipe_write_v,
shared_storage,
qk_gate_topk_idx,
n_block_max,
m_block,
bidh,
bidb,
data_params.cu_seq_q,
data_params.cu_seq_k,
seq_len_q,
seq_len_k);
}
} else {
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
typename Ktraits::TiledMma1 tiled_mma1;
collective_mainloop.mma_init();
PipelineState smem_pipe_read_k, smem_pipe_read_v;
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
collective_mainloop.mma<Ktraits::UseMoba>(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_read_k,
smem_pipe_read_v,
tOrO,
softmax,
qk_gate_topk_idx,
n_block_max,
threadIdx.x - NumCopyThreads,
m_block,
seq_len_q,
seq_len_k,
shared_storage);
const int o_head_stride = data_params.head_num * kHeadDim;
const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;
const int real_seq = seq_len_q - m_block * kBlockM;
collective_mainloop.store<NumMmaThreads>(
mainloop_params,
tOrO,
shared_storage,
tiled_mma1,
threadIdx.x - NumCopyThreads,
o_head_stride,
real_seq,
reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
}
}
template<typename Kernel_traits>
void run_moba_decoder_attn(moba_encoder_attn_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_q * params.batch_size, params.head_num),
static_cast<Element const*>(params.k_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_k * params.batch_size, params.kv_head_num),
static_cast<Element const*>(params.v_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_k * params.batch_size, params.kv_head_num),
params.scale_softmax_log2
});
int num_blocks_m = cutlass::ceil_div(params.max_seq_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
void *kernel;
kernel = (void *)moba_encoder_attention_kernel<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = num_blocks_m;
grid_dims.y = params.head_num;
grid_dims.z = params.batch_size;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
}
template <int kBlockM, int kBlockN, int kMaxN, typename InputType>
void run_moba_encoder_attn_hdim128(moba_encoder_attn_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
constexpr static int kNWarps = kBlockM / 16 + 4;
constexpr static int kStages = 2;
using Ktraits = moba_encoder_attn_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, kMaxN, true, InputType>;
run_moba_decoder_attn<Ktraits>(params, stream);
}
template <typename T>
void DispatchMobaEncoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& out,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int head_dim,
const int batch_size,
const int max_input_length) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
using cute_type = typename cuteType<T>::type;
moba_encoder_attn_params params;
memset(&params, 0, sizeof(moba_encoder_attn_params));
params.q_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>()));
params.k_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>()));
params.v_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>()));
params.o_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(out.data<T>()));
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_seq_q = max_seq_q;
params.max_seq_k = max_seq_k;
params.batch_size = batch_size;
params.gqa_group_size = head_num / kv_head_num;
constexpr float kLog2e = 1.4426950408889634074;
params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
params.qk_gate_topk_idx = const_cast<int*>(qk_gate_topk_idx.data<int>());
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.cu_seq_q_pack = const_cast<int*>(cu_seq_q_pack.data<int>());
run_moba_encoder_attn_hdim128<kBlockM, kBlockN, kMaxN, cute_type>(params, out.stream());
}
void MobaEncoderAttn(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& qk_gate_topk_idx,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& out,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length) {
const int batch_size = seq_len_encoder.dims()[0];
if (q_input.dtype() == paddle::DataType::FLOAT16) {
return
DispatchMobaEncoderAttn<phi::dtype::float16>(
q_input,
k_input,
v_input,
qk_gate_topk_idx,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
seq_len_encoder,
seq_len_decoder,
out,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
head_dim,
batch_size,
max_input_length);
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
return
DispatchMobaEncoderAttn<phi::dtype::bfloat16>(
q_input,
k_input,
v_input,
qk_gate_topk_idx,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
seq_len_encoder,
seq_len_decoder,
out,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
head_dim,
batch_size,
max_input_length);
}
}
PD_BUILD_OP(moba_encoder_attn)
.Inputs({
"q_input",
"k_input",
"v_input",
"qk_gate_topk_idx",
"cu_seq_q",
"cu_seq_k",
"cu_seq_q_pack",
"seq_len_encoder",
"seq_len_decoder",
"out"})
.Attrs({
"max_seq_q: int",
"max_seq_k: int",
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_input_length: int"})
.Outputs({"attn_out"})
.SetInplaceMap({{"out", "attn_out"}})
.SetKernelFn(PD_KERNEL(MobaEncoderAttn));

View File

@@ -1,163 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn.h"
template <typename T, int kBlockSize, int kHeadDim>
__global__ void write_encoder_cachekv_c16(
const T * k_input,
const T * v_input,
const int * cu_seq_k,
const int * seq_len_encoder,
const int * seq_len_decoder,
T * cache_k,
T * cache_v,
const int * block_tables,
const int kv_head_num,
const int max_blocks_per_seq) {
constexpr int kPackSize = 16 / sizeof(T);
const int block_idx = blockIdx.x * kBlockSize;
int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
const int row_idx = tidx / (kHeadDim / kPackSize);
const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize;
const int seq_len = seq_len_encoder[bidb];
if (seq_len == 0) return;
const int ramian_tokens = seq_len - block_idx;
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize];
if (bidh < kv_head_num) {
T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(k_input + base_load_idx + i * kv_head_num * kHeadDim);
}
}
} else {
bidh -= kv_head_num;
const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(v_input + base_load_idx + i * kv_head_num * kHeadDim);
}
}
}
}
void MobaEncoderAttnWriteCacheKv(
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_q,
const std::string &cache_quant_type_str) {
constexpr int kThreads = 128;
constexpr int kHeadDim = 128;
assert(kHeadDim == head_dim);
constexpr int kBlockSize = 64;
const int batch_size = block_tables.dims()[0];
const int max_blocks_per_seq = block_tables.dims()[1];
if (cache_quant_type_str == "none") {
dim3 grid_dims;
grid_dims.x = (max_seq_q + kBlockSize - 1) / kBlockSize;
grid_dims.y = kv_head_num * 2;
grid_dims.z = batch_size;
if (k_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
write_encoder_cachekv_c16<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, k_input.stream()>>>(
const_cast<T*>(k_input.data<T>()),
const_cast<T*>(v_input.data<T>()),
cu_seq_k.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
const_cast<T*>(cache_k.data<T>()),
const_cast<T*>(cache_v.data<T>()),
block_tables.data<int>(),
kv_head_num,
max_blocks_per_seq);
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
write_encoder_cachekv_c16<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, k_input.stream()>>>(
const_cast<T*>(k_input.data<T>()),
const_cast<T*>(v_input.data<T>()),
cu_seq_k.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
const_cast<T*>(cache_k.data<T>()),
const_cast<T*>(cache_v.data<T>()),
block_tables.data<int>(),
kv_head_num,
max_blocks_per_seq);
}
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Quantized cache not implemented for cache_quant_type = %s", cache_quant_type_str.c_str()));
}
}
PD_BUILD_OP(moba_encoder_attn_write_cache_kv)
.Inputs({
"k_input",
"v_input",
"cu_seq_k",
"seq_len_encoder",
"seq_len_decoder",
"cache_k",
"cache_v",
"block_tables",
paddle::Optional("cache_k_quant_scale"),
paddle::Optional("cache_v_quant_scale"),
paddle::Optional("cache_k_dequant_scale"),
paddle::Optional("cache_v_dequant_scale"),
paddle::Optional("cache_k_zero_points"),
paddle::Optional("cache_v_zero_points")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_seq_q: int",
"cache_quant_type_str: std::string"})
.Outputs({"cache_k_out", "cache_v_out"})
.SetInplaceMap({{"cache_k", "cache_k_out"},
{"cache_v", "cache_v_out"}})
.SetKernelFn(PD_KERNEL(MobaEncoderAttnWriteCacheKv));

View File

@@ -1,341 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
template <typename T, int knthreads, int moba_block_size, int kBlockM, int kBlockMaxN, int searchtimes>
__global__ void qk_gate_sort_encoder_kernel(
const T* qk_gate_weight,
int * qk_gate_topk_idx,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int* cu_seq_q,
const int* cu_seq_k,
const int* cu_seq_q_pack,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int kGqaGroupSize,
const int top_k_left,
const int top_k_right) {
const int bidt = blockIdx.x * kBlockM;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
constexpr int kPackSize = kBlockMaxN / knthreads;
static_assert(kBlockMaxN % knthreads == 0);
const int seq_len_q = seq_len_encoder[bidb];
if (seq_len_q == 0 || bidt >= seq_len_q) {
return;
}
const int seq_len_k = (bidt + kBlockM + seq_len_decoder[bidb]);
const int seq_len_moba = seq_len_k / moba_block_size;
using SrcType = Vec<T, kPackSize>;
using SrcType_f = Vec<float, kPackSize>;
using SrcType_i = Vec<int, kPackSize>;
SrcType src;
SrcType_f src_f;
SrcType_i select_idx;
select_idx.set_zero();
const int store_idx = cu_seq_q_pack[bidb] / kBlockM * head_num * kBlockMaxN + bidh * kBlockMaxN + blockIdx.x * head_num * kBlockMaxN + tidx * kPackSize;
if (seq_len_k < use_moba_seq_limit) {
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
select_idx.data.elt[i] = 1;
}
select_idx.store_to(qk_gate_topk_idx + store_idx);
return;
}
const int load_offset = (cu_seq_q[bidb] + bidt) * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize;
const int data_len = seq_len_moba - tidx * kPackSize;
#pragma unroll
for (int t = 0; t < kBlockM; t++) {
if (bidt + t >= seq_len_q) {
break;
}
src.load_from(qk_gate_weight + load_offset + t * head_num * kBlockMaxN);
float min_global = FLT_MAX;
float max_global = -FLT_MAX;
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (i < data_len) {
src_f.data.elt[i] = float(src.data.elt[i]);
min_global = min(min_global, src_f.data.elt[i]);
} else {
src_f.data.elt[i] = -FLT_MAX;
}
max_global = max(max_global, src_f.data.elt[i]);
}
max_global = BlockAllReduce<float, MaxOp<float>, knthreads>(max_global);
min_global = BlockAllReduce<float, MinOp<float>, knthreads>(min_global);
float right_limit = max_global;
float left_limit = min_global;
float mid_limit;
int count;
if (right_limit == left_limit) {
mid_limit = (left_limit + right_limit) * 0.5f;
} else {
#pragma unroll
for (int i = 0; i < searchtimes; i++) {
mid_limit = (left_limit + right_limit) * 0.5f;
count = get_data_count<kPackSize, knthreads>(src_f.data.elt, mid_limit);
if (count < top_k_left) {
right_limit = mid_limit;
} else if (count > top_k_right) {
left_limit = mid_limit;
} else {
break;
}
}
}
#pragma unroll
for (int i = 0; i < kPackSize; i++) {
if (src_f.data.elt[i] >= mid_limit) {
select_idx.data.elt[i] = 1;
}
}
}
if (tidx == 0) {
select_idx.data.elt[0] = 1;
}
__align__(16) __shared__ int qk_gate_mem[kBlockMaxN];
__align__(16) __shared__ int qk_continue_idx_mem[kBlockMaxN];
select_idx.store_to(qk_gate_mem + tidx * kPackSize);
__syncthreads();
if (tidx == 0) {
int cur_idx = 0;
int idx = -1;
const int last_idx = seq_len_moba - 1;
while (last_idx + idx >= 0 && qk_gate_mem[last_idx + idx] == 0) {
idx--;
}
qk_continue_idx_mem[cur_idx] = -idx;
cur_idx++;
for (int i = last_idx - 1; i >= 0; --i) {
if (qk_gate_mem[i] == 1) {
int idx = -1;
while (i + idx >= 0 && qk_gate_mem[i + idx] == 0) {
idx--;
}
qk_continue_idx_mem[cur_idx] = -idx;
cur_idx++;
}
}
qk_continue_idx_mem[cur_idx] = 10000000;
}
__syncthreads();
*reinterpret_cast<SrcType_i *>(qk_gate_topk_idx + store_idx) = reinterpret_cast<SrcType_i *>(qk_continue_idx_mem)[tidx];
}
template <int kBlockM, int kMaxN, int moba_block_size, typename T>
void qk_gate_sort_encoder(
const T* qk_gate_weight,
int * qk_gate_topk_idx,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int* cu_seq_q,
const int* cu_seq_k,
const int* cu_seq_q_pack,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int batch_size,
const int top_k_left,
const int top_k_right,
cudaStream_t stream) {
constexpr int kPackSize = 16 / sizeof(T);
const int gqa_group_size = head_num / kv_head_num;
const int knthreads = kMaxN / kPackSize;
const int searchtimes = 6;
dim3 grid_dims;
grid_dims.x = (max_seq_q + kBlockM - 1) / kBlockM;
grid_dims.y = head_num;
grid_dims.z = batch_size;
constexpr auto kernel = qk_gate_sort_encoder_kernel<T, knthreads, moba_block_size, kBlockM, kMaxN, searchtimes>;
kernel<<<grid_dims, knthreads, 0, stream>>>(
qk_gate_weight,
qk_gate_topk_idx,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
gqa_group_size,
top_k_left,
top_k_right);
}
template <typename T>
std::vector<paddle::Tensor> DispatchQkSortEncoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& q_pack_tokens,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
using cute_type = typename cuteType<T>::type;
const int batch_size = seq_len_encoder.dims()[0];
paddle::Tensor qk_gate_topk_idx = paddle::empty({q_pack_tokens.data<int>()[0] / kBlockM, head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place());
qk_gate_sort_encoder<kBlockM, kMaxN, kMobaBlockSize, cute_type>(
reinterpret_cast<const cute_type *>(qk_gate_weight.data<T>()),
qk_gate_topk_idx.data<int>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
cu_seq_q_pack.data<int>(),
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
batch_size,
top_k_left,
top_k_right,
qk_gate_weight.stream());
return {qk_gate_topk_idx};
}
std::vector<paddle::Tensor> QkSortEncoder(
const paddle::Tensor& qk_gate_weight,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& cu_seq_q_pack,
const paddle::Tensor& q_pack_tokens,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int top_k_left,
const int top_k_right,
const int use_moba_seq_limit) {
if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) {
return std::move(
DispatchQkSortEncoder<phi::dtype::float16>(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
q_pack_tokens,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
top_k_left,
top_k_right,
use_moba_seq_limit
)
);
} else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) {
return std::move(
DispatchQkSortEncoder<phi::dtype::bfloat16>(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
q_pack_tokens,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
top_k_left,
top_k_right,
use_moba_seq_limit
)
);
}
}
PD_BUILD_OP(moba_qk_sort_encoder)
.Inputs({
"qk_gate_weight",
"seq_len_encoder",
"seq_len_decoder",
"cu_seq_q",
"cu_seq_k",
"cu_seq_q_pack",
"q_pack_tokens"})
.Attrs({
"max_seq_q: int",
"max_seq_k: int",
"head_num: int",
"kv_head_num: int",
"top_k_left: int",
"top_k_right: int",
"use_moba_seq_limit: int"})
.Outputs({"qk_gate_topk_idx"})
.SetKernelFn(PD_KERNEL(QkSortEncoder));

View File

@@ -1,194 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "../moba_attn_utils.hpp"
using namespace cute;
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
}
__forceinline__ __device__ __half2 half_exp(__half2 x) {
uint32_t tmp_out, tmp_in;
tmp_in = reinterpret_cast<uint32_t&>(x);
asm ("ex2.approx.f16x2 %0, %1;\n"
: "=r"(tmp_out)
: "r"(tmp_in));
__half2 out = reinterpret_cast<__half2&>(tmp_out);
return out;
}
// Apply the exp to all the elements.
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const float max_scaled = max(mi) * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
CUTLASS_DEVICE Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
cute::fill(scores_scale, 1.f);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*zero_init=*/false>(scores, row_max);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale(mi);
}
}
return scores_scale;
};
template<bool Is_first, typename Tensor0>
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
cute::fill(scores_scale, 1.f);
} else {
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
}
return scores_scale;
};
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT scores_scale;
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.0f / sum;
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
scores_scale(mi) = inv_sum;
}
return scores_scale;
};
template<typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
};

View File

@@ -1,288 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename T, int kBlockSize, int kHeadDim>
__global__ void get_kv_from_cache_c16_kernel(
T * k_input,
T * v_input,
const int * seq_len_encoder,
const int * seq_len_decoder,
const int * cu_seq_k,
const T * cache_k,
const T * cache_v,
const int * block_tables,
const int kv_head_num,
const int head_dim,
const int batch_size,
const int max_input_length,
const int max_blocks_per_seq) {
const int block_idx = blockIdx.x;
int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int seq_len = seq_len_decoder[bidb] + seq_len_encoder[bidb];
const int tidx = threadIdx.x;
const int base_token_idx = block_idx * kBlockSize;
if (base_token_idx >= seq_len || seq_len_encoder[bidb] == 0) {
return;
}
constexpr int kPackSize = 16 / sizeof(T);
const int row_idx = tidx / (kHeadDim / kPackSize);
const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize;
const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx];
const int ramian_tokens = seq_len - base_token_idx;
if (bidh < kv_head_num) {
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_k + cache_offset + i * kHeadDim);
}
}
} else {
bidh -= kv_head_num;
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
#pragma unroll
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
if (i < ramian_tokens) {
*reinterpret_cast<float4*>(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_v + cache_offset + i * kHeadDim);
}
}
}
}
template <typename T>
void get_kv_from_cache(
T * k_input,
T * v_input,
const int * seq_len_encoder,
const int * seq_len_decoder,
const int * cu_seq_k,
const void * cache_k,
const void * cache_v,
const int * block_tables,
const T * cache_k_dequant_scale,
const T * cache_v_dequant_scale,
const T * cache_k_zero_points,
const T * cache_v_zero_points,
const int kv_head_num,
const int head_dim,
const int max_seq_k,
const int batch_size,
const int max_input_length,
const int max_blocks_per_seq,
const std::string &cache_quant_type_str,
cudaStream_t stream) {
constexpr int kThreads = 128;
constexpr int kHeadDim = 128;
assert(kHeadDim == head_dim);
constexpr int kBlockSize = 64;
if (cache_quant_type_str == "none") {
dim3 grid_dims;
grid_dims.x = (max_seq_k + kBlockSize - 1) / kBlockSize;
grid_dims.y = kv_head_num * 2;
grid_dims.z = batch_size;
get_kv_from_cache_c16_kernel<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, stream>>>(
k_input,
v_input,
seq_len_encoder,
seq_len_decoder,
cu_seq_k,
reinterpret_cast<const T*>(cache_k),
reinterpret_cast<const T*>(cache_v),
block_tables,
kv_head_num,
head_dim,
batch_size,
max_input_length,
max_blocks_per_seq);
} else {
PD_THROW("Only supported cache_quant_type_str in ['none'].");
}
}
void GetKVFromCache(
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cache_k,
const paddle::Tensor& cache_v,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int max_seq_k,
const std::string &cache_quant_type_str) {
if (k_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
using cute_type = typename cuteType<T>::type;
get_kv_from_cache<cute_type>(
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_k.data<int>(),
cache_k.data(),
cache_v.data(),
block_tables.data<int>(),
cache_k_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_dequant_scale.get().data<T>())) : nullptr,
cache_v_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_dequant_scale.get().data<T>())) : nullptr,
cache_k_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_zero_points.get().data<T>())) : nullptr,
cache_v_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_zero_points.get().data<T>())) : nullptr,
kv_head_num,
head_dim,
max_seq_k,
seq_len_encoder.dims()[0],
max_input_length,
block_tables.dims()[1],
cache_quant_type_str,
k_input.stream());
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
using cute_type = typename cuteType<T>::type;
get_kv_from_cache<cute_type>(
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_k.data<int>(),
cache_k.data(),
cache_v.data(),
block_tables.data<int>(),
cache_k_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_dequant_scale.get().data<T>())) : nullptr,
cache_v_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_dequant_scale.get().data<T>())) : nullptr,
cache_k_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_zero_points.get().data<T>())) : nullptr,
cache_v_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_zero_points.get().data<T>())) : nullptr,
kv_head_num,
head_dim,
max_seq_k,
seq_len_encoder.dims()[0],
max_input_length,
block_tables.dims()[1],
cache_quant_type_str,
k_input.stream());
}
}
__global__ void get_cur_cu_seq_len_k_kernel(
const int* __restrict__ seq_lens_encoder,
const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_this_time,
int* __restrict__ cu_seqlens_k,
int* __restrict__ cu_seq_q_pack,
int* __restrict__ q_pack_tokens,
const int pack_size,
const int bsz) {
int total_tokens = 0;
cu_seqlens_k[0] = 0;
cu_seq_q_pack[0] = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int cache_len = seq_lens_decoder[bid];
const int q_len = seq_lens_encoder[bid];
if (q_len <= 0) {
cache_len = 0;
}
total_tokens += (cache_len + q_len);
cu_seqlens_k[bid + 1] = total_tokens;
cu_seq_q_pack[bid + 1] = cu_seq_q_pack[bid] + (q_len + pack_size -1) / pack_size * pack_size;
}
q_pack_tokens[0] = cu_seq_q_pack[bsz];
}
std::vector<paddle::Tensor> GetCurCuSeqLenk(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int pack_size) {
auto stream = seq_lens_decoder.stream();
auto place = seq_lens_decoder.place();
int bsz = seq_lens_this_time.shape()[0];
paddle::Tensor cu_seq_q_pack = paddle::empty({bsz + 1}, paddle::DataType::INT32, place);
paddle::Tensor cu_seqlens_k = paddle::empty({bsz + 1}, paddle::DataType::INT32, place);
paddle::Tensor q_pack_tokens = paddle::empty({1}, paddle::DataType::INT32, place);
get_cur_cu_seq_len_k_kernel<<<1, 1, 0, stream>>>(
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seqlens_k.data<int>(),
cu_seq_q_pack.data<int>(),
q_pack_tokens.data<int>(),
pack_size,
bsz
);
auto q_pack_tokens_cpu = q_pack_tokens.copy_to(paddle::CPUPlace(), true);
return {cu_seq_q_pack, cu_seqlens_k, q_pack_tokens_cpu};
}
PD_BUILD_OP(get_kv_from_cache)
.Inputs({
"k_input",
"v_input",
"cu_seq_k",
"seq_len_encoder",
"seq_len_decoder",
"cache_k",
"cache_v",
"block_tables",
paddle::Optional("cache_k_dequant_scale"),
paddle::Optional("cache_v_dequant_scale"),
paddle::Optional("cache_k_zero_points"),
paddle::Optional("cache_v_zero_points")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_input_length: int",
"max_seq_k: int",
"cache_quant_type_str: std::string"})
.Outputs({"k_input_out", "v_input_out"})
.SetInplaceMap({{"k_input", "k_input_out"},
{"v_input", "v_input_out"}})
.SetKernelFn(PD_KERNEL(GetKVFromCache));
PD_BUILD_OP(get_cur_cu_seq_len_k)
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time"})
.Attrs({
"pack_size: int"})
.Outputs({"cu_seq_q_pack", "cu_seqlens_k", "q_pack_tokens"})
.SetKernelFn(PD_KERNEL(GetCurCuSeqLenk));

View File

@@ -1,221 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename T, int moba_block_size, int kHeadDim, int kMaxN>
__global__ void moba_mlp_einsum_kernel(
const T * src_data,
const T * weight_data,
const int * seq_lens_encoder,
const int * seq_lens_decoder,
const int * cu_seq_k,
T * dst_data,
const int head_num) {
constexpr int kPackSize = 16 / sizeof(T);
const int block_idx = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int tidx = threadIdx.x;
const int lane_id = tidx % 32;
const int warp_id = tidx / 32;
__align__(16) __shared__ T local_sum_mem[128 / 32 * kHeadDim];
const int seq_len_encoder = seq_lens_encoder[bidb];
const int seq_len_decoder = seq_len_encoder + seq_lens_decoder[bidb];
const int seq_len_this_block = seq_len_decoder - block_idx * moba_block_size;
if (seq_len_encoder == 0 || seq_len_this_block <= 0) {
return;
}
using SrcType = Vec<T, kPackSize>;
constexpr int tidx_per_row = kHeadDim / kPackSize;
const int row_idx = tidx / tidx_per_row;
const int col_idx = tidx % tidx_per_row * kPackSize;
const int src_base_idx = cu_seq_k[bidb] * head_num * kHeadDim + block_idx * moba_block_size * head_num * kHeadDim + bidh * kHeadDim + row_idx * head_num * kHeadDim + col_idx;
const int weight_base_idx = bidh * kHeadDim * moba_block_size + row_idx * kHeadDim + col_idx;
constexpr int step = 128 / tidx_per_row;
SrcType sums, src, weight;
sums.set_zero();
for (int i = 0; i < moba_block_size; i += step) {
if (i >= seq_len_this_block) {
break;
}
src.load_from(src_data + src_base_idx + i * head_num * kHeadDim);
weight.load_from(weight_data + weight_base_idx + i * kHeadDim);
sums.fma(src, weight);
}
SrcType neighbor;
#pragma unroll
for (int i = 0; i < kPackSize; i+=2) {
*reinterpret_cast<int32_t*>(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast<int32_t*>(sums.data.elt + i), 16);
}
sums.add(neighbor);
if (lane_id < 16) {
sums.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize);
}
__syncthreads();
using pack_half = std::conditional_t<std::is_same<T, phi::dtype::float16>::value, __half2, nv_bfloat162>;
pack_half * local_sum_mem_half = reinterpret_cast<pack_half*>(local_sum_mem);
if (tidx < kHeadDim / 2) {
pack_half local_sum_half = local_sum_mem_half[tidx];
#pragma unroll
for (int i = 1; i < 4; i++) {
local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)];
}
local_sum_mem_half[tidx] = local_sum_half;
}
__syncthreads();
const int store_row_id = tidx / (kHeadDim / kPackSize);
const int store_col_id = tidx % (kHeadDim / kPackSize) * kPackSize;
sums.load_from(local_sum_mem + store_col_id);
const int base_store_idx = bidb * kMaxN * head_num * kHeadDim + (block_idx * (moba_block_size / 128) + store_row_id) * head_num * kHeadDim + bidh * kHeadDim + store_col_id;
if (store_row_id < moba_block_size / 128) {
sums.store_to(dst_data + base_store_idx);
}
}
template <typename T, int kHeadDim, int kMaxN>
void moba_mlp_einsum(
const T * src_data,
const T * weight_data,
const int * seq_lens_encoder,
const int * seq_lens_decoder,
const int * cu_seq_k,
T * dst_data,
const int moba_block_size,
const int max_seq_len,
const int head_num,
const int batch_size,
cudaStream_t stream) {
dim3 grid_dims;
grid_dims.x = (max_seq_len + moba_block_size - 1) / moba_block_size;
grid_dims.y = head_num;
grid_dims.z = batch_size;
if (moba_block_size == 1024) {
moba_mlp_einsum_kernel<T, 1024, kHeadDim, kMaxN><<<grid_dims, 128, 0, stream>>>(
src_data,
weight_data,
seq_lens_encoder,
seq_lens_decoder,
cu_seq_k,
dst_data,
head_num);
} else if (moba_block_size == 128) {
moba_mlp_einsum_kernel<T, 128, kHeadDim, kMaxN><<<grid_dims, 128, 0, stream>>>(
src_data,
weight_data,
seq_lens_encoder,
seq_lens_decoder,
cu_seq_k,
dst_data,
head_num);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"MobaMlpEinsum not implemented for moba_block_size = %d", moba_block_size));
}
}
std::vector<paddle::Tensor> MobaMlpEinsum(
const paddle::Tensor& k_input,
const paddle::Tensor& attn_gate_weight,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seq_k,
const int max_seq_len,
const int kv_head_num) {
const int kHeadDim = 128;
const int kMaxN = 1024;
const int moba_block_size = attn_gate_weight.dims()[1];
const int batch_size = seq_lens_encoder.dims()[0];
paddle::Tensor k_gate_weight = paddle::zeros({batch_size, kMaxN, kv_head_num, kHeadDim}, k_input.dtype(), k_input.place());
if (k_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
moba_mlp_einsum<T, kHeadDim, kMaxN>(
const_cast<T*>(k_input.data<T>()),
const_cast<T*>(attn_gate_weight.data<T>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int*>(cu_seq_k.data<int>()),
k_gate_weight.data<T>(),
moba_block_size,
max_seq_len,
kv_head_num,
batch_size,
k_input.stream()
);
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
moba_mlp_einsum<T, kHeadDim, kMaxN>(
const_cast<T*>(k_input.data<T>()),
const_cast<T*>(attn_gate_weight.data<T>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int*>(cu_seq_k.data<int>()),
k_gate_weight.data<T>(),
moba_block_size,
max_seq_len,
kv_head_num,
batch_size,
k_input.stream()
);
}
return {k_gate_weight};
}
PD_BUILD_OP(moba_mlp_einsum)
.Inputs({
"k_input",
"attn_gate_weight",
"seq_lens_encoder",
"seq_lens_decoder",
"cu_seq_k"})
.Attrs({
"max_seq_len: int",
"kv_head_num: int"})
.Outputs({"k_gate"})
.SetKernelFn(PD_KERNEL(MobaMlpEinsum));

View File

@@ -1,465 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
template <typename input_type, int kBlockM, int kBlockN, int kMobaBlockSize, int kMaxN, int kHeadDim, bool is_split_kv>
__global__ void qk_gemm_kernel(
const input_type *q_input,
const input_type *k_gate_mean,
input_type *qk_gate_weight,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *cu_seq_q,
const int *cu_seq_k,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int kGQA_groupsize) {
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
using SmemLayoutAtomQ = decltype(
cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, input_type,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(
cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, input_type, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
using SmemLayoutAtomQK = decltype(
cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, input_type,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutQK = decltype(tile_to_shape(SmemLayoutAtomQK{}, select<0, 1>(TileShape_MNK{})));
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<input_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = std::conditional_t<
is_split_kv,
Layout<Shape<_1,_4,_1>>,
Layout<Shape<_4,_1,_1>>
>;
using PermutationMNK = std::conditional_t<
is_split_kv,
Tile<_16,_64,_16>,
Tile<_64,_16,_16>
>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
ValLayoutMNK,
PermutationMNK>;
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, input_type>;
using SmemCopyAtomQK = Copy_Atom<cute::SM90_U32x4_STSM_N, input_type>;
constexpr int kNThreads = 128;
constexpr int kThreadPerValue = 16 / sizeof(input_type);
constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue;
constexpr int kThreadsPerRowQK = kBlockN / kThreadPerValue;
using GmemLayoutAtom = Layout<
Shape <Int<kNThreads / kThreadsPerRow>, Int<kThreadsPerRow>>,
Stride<Int<kThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, input_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
using GmemLayoutAtomQK = Layout<
Shape <Int<kNThreads / kThreadsPerRowQK>, Int<kThreadsPerRowQK>>,
Stride<Int<kThreadsPerRowQK>, _1>>;
using GmemTiledCopyQK = decltype(
make_tiled_copy(Copy_Atom<
UniversalCopy<cutlass::uint128_t>, input_type>{},
GmemLayoutAtomQK{},
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
int mn_block = blockIdx.x;
const int bidb = blockIdx.y;
const int bidh = blockIdx.z;
const int bidh_k = bidh / kGQA_groupsize;
const int tidx = threadIdx.x;
extern __shared__ char smem_[];
const int seq_len_q = seq_len_encoder[bidb];
const int seq_len_k = seq_len_decoder[bidb];
const int seq_len_qk = seq_len_q + seq_len_k;
int q_head_stride;
const int k_head_stride = kv_head_num * kHeadDim;
int qk_head_stride;
int offset_q;
int offset_k;
int offset_qk;
int remain_q_seq;
if constexpr (is_split_kv) {
if (seq_len_k < use_moba_seq_limit || seq_len_k == 0) {
return;
}
mn_block *= kBlockN;
q_head_stride = kHeadDim;
qk_head_stride = kMaxN;
if (mn_block >= (seq_len_k + kMobaBlockSize - 1) / kMobaBlockSize) {
return;
}
offset_q = cu_seq_q[bidb] * head_num * kHeadDim + bidh * kGQA_groupsize * kHeadDim;
offset_k = (bidb * kMaxN + mn_block) * k_head_stride + bidh * kHeadDim;
offset_qk = bidb * head_num * kMaxN + bidh * kGQA_groupsize * kMaxN + mn_block;
remain_q_seq = kGQA_groupsize;
} else {
if (seq_len_q == 0 || seq_len_qk < use_moba_seq_limit) {
return;
}
q_head_stride = head_num * kHeadDim;
qk_head_stride = head_num * kMaxN;
mn_block *= kBlockM;
if (mn_block >= seq_len_q) {
return;
}
offset_q = (cu_seq_q[bidb] + mn_block) * q_head_stride + bidh * kHeadDim;
offset_k = bidb * kMaxN * k_head_stride + bidh_k * kHeadDim;
offset_qk = (cu_seq_q[bidb] + mn_block) * qk_head_stride + bidh * kMaxN;
remain_q_seq = seq_len_q - mn_block;
}
Tensor gQ = make_tensor(make_gmem_ptr(q_input + offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(q_head_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(k_gate_mean + offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(k_head_stride, _1{}));
Tensor gQK = make_tensor(make_gmem_ptr(qk_gate_weight + offset_qk),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(qk_head_stride, _1{}));
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<input_type *>(smem_)), SmemLayoutK{});
Tensor sQ = make_tensor(sK.data() + size(sK), SmemLayoutQ{});
Tensor sQK = make_tensor(sK.data() + size(sK), SmemLayoutQK{});
auto gmem_tiled_copy = GmemTiledCopy{};
auto gmem_tiled_copy_qk = GmemTiledCopyQK{};
auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx);
auto gmem_thr_copy_qk = gmem_tiled_copy_qk.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy.partition_D(sQ);
Tensor tKgK = gmem_thr_copy.partition_S(gK);
Tensor tKsK = gmem_thr_copy.partition_D(sK);
Tensor tQKgQK = gmem_thr_copy_qk.partition_S(gQK);
Tensor tQKsQK = gmem_thr_copy_qk.partition_D(sQK);
Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim));
Tensor tQcQ = gmem_thr_copy.partition_S(cQ);
Tensor cK = make_identity_tensor(make_shape(kBlockN, kHeadDim));
Tensor tKcK = gmem_thr_copy.partition_S(cK);
Tensor cQK = make_identity_tensor(make_shape(kBlockM, kBlockN));
Tensor tQKcQK = gmem_thr_copy.partition_S(cQK);
if (remain_q_seq >= kBlockM) {
copy(gmem_tiled_copy, tQgQ, tQsQ, tQcQ);
} else {
copy<false>(gmem_tiled_copy, tQgQ, tQsQ, tQcQ, remain_q_seq);
}
copy(gmem_tiled_copy, tKgK, tKsK, tKcK);
cute::cp_async_fence();
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
Tensor tSrK = thr_mma.partition_fragment_B(sK);
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_tiled_copy_QK = make_tiled_copy_C(SmemCopyAtomQK{}, tiled_mma);
auto smem_thr_copy_QK = smem_tiled_copy_QK.get_thread_slice(tidx);
Tensor tsQK = smem_thr_copy_QK.partition_D(sQK);
const int n_blocks = is_split_kv ? 1 : cute::ceil_div(cute::ceil_div(seq_len_qk, kMobaBlockSize), kBlockN);
#pragma unroll
for (int n_block = 0; n_block < n_blocks; ++n_block) {
clear(acc_s);
cp_async_wait<0>();
__syncthreads();
if (n_block == 0) {
gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
} else {
gemm<true>(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
}
if constexpr (!is_split_kv) {
if (n_block < n_blocks - 1) {
__syncthreads();
tKgK.data() = tKgK.data() + kBlockN * k_head_stride;
copy(gmem_tiled_copy, tKgK, tKsK, tKcK);
cute::cp_async_fence();
}
}
Tensor rS = convert_type<input_type>(acc_s);
Tensor trQK = smem_thr_copy_QK.retile_S(rS);
cute::copy(smem_tiled_copy_QK, trQK, tsQK);
__syncthreads();
if (remain_q_seq >= kBlockM) {
copy(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK);
} else {
copy<false>(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK, remain_q_seq);
}
if constexpr (!is_split_kv) {
__syncthreads();
tQKgQK.data() = tQKgQK.data() + kBlockN;
}
}
}
template <typename input_type, int kBlockM, int kBlockN, int kMobaBlockSize, int kMaxN, bool is_split_kv>
void qk_gemm(
const input_type *q_input,
const input_type *k_gate_mean,
input_type *qk_gate_weight,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *cu_seq_q,
const int *cu_seq_k,
const int use_moba_seq_limit,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int bsz,
cudaStream_t stream) {
const int gqa_group_size = head_num / kv_head_num;
dim3 grid_dims;
const int num_m_block = (max_seq_q + kBlockM - 1) / kBlockM;
const int num_n_block = ((max_seq_k + kMobaBlockSize - 1) / kMobaBlockSize + kBlockN - 1) / kBlockN;
if (is_split_kv) {
grid_dims.x = num_n_block;
grid_dims.z = kv_head_num;
} else {
grid_dims.x = num_m_block;
grid_dims.z = head_num;
}
grid_dims.y = bsz;
constexpr int kHeadDim = 128;
constexpr int smemq = kBlockM * kHeadDim * sizeof(input_type);
constexpr int smemk = kBlockN * kHeadDim * sizeof(input_type);
constexpr int smemqk = kBlockM * kBlockN * sizeof(input_type);
const int smem_size = smemk + max(smemq, smemqk);
auto kernel = &qk_gemm_kernel<input_type, kBlockM, kBlockN, kMobaBlockSize, kMaxN, kHeadDim, is_split_kv>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel<<<grid_dims, 128, smem_size, stream>>>(
q_input,
k_gate_mean,
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
gqa_group_size);
}
template <typename T>
std::vector<paddle::Tensor> DispatchMobaQKGemm(
const paddle::Tensor& q_input,
const paddle::Tensor& k_block_means,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const bool is_split_kv,
const int use_moba_seq_limit) {
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
const int batch_size = seq_len_encoder.dims()[0];
using cute_type = typename cuteType<T>::type;
if (is_split_kv) {
paddle::Tensor qk_gate_weight = paddle::empty({batch_size, head_num, kMaxN}, q_input.dtype(), q_input.place());
qk_gemm<cute_type, 16, kMobaBlockSize, kMobaBlockSize, kMaxN, true>(
reinterpret_cast<const cute_type*>(q_input.data<T>()),
reinterpret_cast<const cute_type*>(k_block_means.data<T>()),
reinterpret_cast<cute_type*>(qk_gate_weight.data<T>()),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
batch_size,
q_input.stream()
);
return {qk_gate_weight};
} else {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
const int token_num = q_input.dims()[0];
paddle::Tensor qk_gate_weight = paddle::empty({token_num, head_num, kMaxN}, q_input.dtype(), q_input.place());
qk_gemm<cute_type, kBlockM, kBlockN, kMobaBlockSize, kMaxN, false>(
reinterpret_cast<cute_type *>(const_cast<T*>(q_input.data<T>())),
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
reinterpret_cast<cute_type *>(qk_gate_weight.data<T>()),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
use_moba_seq_limit,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
batch_size,
q_input.stream());
return {qk_gate_weight};
}
}
std::vector<paddle::Tensor> MobaQKGemm(
const paddle::Tensor& q_input,
const paddle::Tensor& k_block_means,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const bool is_split_kv,
const int use_moba_seq_limit) {
if (q_input.dtype() == paddle::DataType::FLOAT16) {
return std::move(
DispatchMobaQKGemm<phi::dtype::float16>(
q_input,
k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
is_split_kv,
use_moba_seq_limit
)
);
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
return std::move(
DispatchMobaQKGemm<phi::dtype::bfloat16>(
q_input,
k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
is_split_kv,
use_moba_seq_limit
)
);
}
}
PD_BUILD_OP(moba_qk_gemm)
.Inputs({
"q_input",
"k_block_means",
"seq_len_encoder",
"seq_len_decoder",
"cu_seq_q",
"cu_seq_k"})
.Attrs({
"max_seq_q: int",
"max_seq_k: int",
"head_num: int",
"kv_head_num: int",
"is_split_kv: bool",
"use_moba_seq_limit: int"})
.Outputs({"qk_gate_weight"})
.SetKernelFn(PD_KERNEL(MobaQKGemm));

View File

@@ -1,370 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "moba_attn/moba_attn_utils.hpp"
#include "moba_attn/moba_attn.h"
template <typename input_type, int moba_block_size, int kBlockM, int kMaxN, int tokens_per_block, bool need_k_mean>
__global__ void fused_block_mean_and_rope_kernel(
const input_type *qkv_input,
const input_type *qkv_bias,
input_type *k_gate_mean,
input_type *q_input,
input_type *k_input,
input_type *v_input,
const float *rope_sin_cos,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *cu_seq_q,
const int *cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int max_input_length) {
constexpr int kPackSize = 16 / sizeof(input_type);
constexpr int kHeadDim = 128;
using src_type = Vec<input_type, kPackSize>;
using rope_type = Vec<float, kPackSize / 2>;
using pack_half = std::conditional_t<std::is_same<input_type, cutlass::half_t>::value, __half2, nv_bfloat162>;
__align__(16) __shared__ input_type local_sum_mem[128 / 32 * kHeadDim];
const int bidb = blockIdx.x;
const int bidh = blockIdx.y;
const int bidt_q = blockIdx.z * tokens_per_block;
const int bidt_v = blockIdx.z * tokens_per_block;
const int bidt_k = need_k_mean ? blockIdx.z * moba_block_size : blockIdx.z * tokens_per_block;
const int tidx = threadIdx.x;
const int lane_id = tidx % 32;
const int warp_id = tidx / 32;
const int seq_len = seq_len_encoder[bidb];
const int seq_len_start = seq_len_decoder[bidb];
if (seq_len == 0) {
return;
}
const int all_head_num = head_num + 2 * kv_head_num;
const int hidden = all_head_num * kHeadDim;
const int row_idx = tidx / (kHeadDim / kPackSize);
const int col_idx = tidx % (kHeadDim / kPackSize);
const int bias_idx = bidh * kHeadDim + col_idx * kPackSize;
src_type src, src_bias;
rope_type sin, cos;
const bool need_add_bias = qkv_bias != nullptr;
if (need_add_bias) {
src_bias.load_from(qkv_bias + bias_idx);
}
if (bidh < head_num) {
const int cur_token = bidt_q + row_idx;
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
if (cur_token < seq_len) {
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
if (need_add_bias) {
src.add(src_bias);
}
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
src.store_to(q_input + (cu_seq_q[bidb] + cur_token) * head_num * kHeadDim + bias_idx);
}
} else if (bidh < head_num + kv_head_num) {
if constexpr (!need_k_mean) {
const int cur_token = bidt_k + row_idx;
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
if (cur_token < seq_len) {
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
if (need_add_bias) {
src.add(src_bias);
}
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * head_num * kHeadDim + bias_idx- head_num * kHeadDim);
}
} else {
if (bidt_k >= seq_len) {
return;
}
src_type local_sum;
local_sum.set_zero();
const input_type* qkv = qkv_input + cu_seq_q[bidb] * hidden + bias_idx;
for (int i = 0; i < moba_block_size; i += tokens_per_block) {
const int cur_token = bidt_k + i + row_idx;
if (cur_token < seq_len) {
src.load_from(qkv + cur_token * hidden);
if (need_add_bias) {
src.add(src_bias);
}
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
sin.load_from(sin_rope);
cos.load_from(cos_rope);
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - head_num * kHeadDim);
local_sum.add(src);
}
}
src_type neighbor;
#pragma unroll
for (int i = 0; i < kPackSize; i+=2) {
*reinterpret_cast<int32_t*>(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast<int32_t*>(local_sum.data.elt + i), 16);
}
local_sum.add(neighbor);
if (lane_id < 16) {
local_sum.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize);
}
__syncthreads();
pack_half * local_sum_mem_half = reinterpret_cast<pack_half*>(local_sum_mem);
pack_half local_sum_half = local_sum_mem_half[tidx];
if (tidx < kHeadDim / 2) {
#pragma unroll
for (int i = 1; i < 4; i++) {
local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)];
}
float inv_tokens_sum = fdividef(1.0f, min(seq_len - bidt_k, moba_block_size));
local_sum_half *= float_2_half2<input_type>(inv_tokens_sum);
const int store_mean_idx = ((bidb * kMaxN + blockIdx.z + seq_len_start / moba_block_size) * kv_head_num * kHeadDim + (bidh - head_num) * kHeadDim) / 2 + tidx;
reinterpret_cast<pack_half*>(k_gate_mean)[store_mean_idx] = local_sum_half;
}
}
} else {
const int cur_token = bidt_v + row_idx;
if (cur_token < seq_len) {
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
if (need_add_bias) {
src.add(src_bias);
}
src.store_to(v_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - (head_num + kv_head_num) * kHeadDim);
}
}
}
template <typename input_type, int moba_block_size, int kBlockM, int kMaxN>
void fused_block_mean_and_rope(
const input_type *qkv_input,
const input_type *qkv_bias,
input_type *k_gate_mean,
input_type *q_input,
input_type *k_input,
input_type *v_input,
const float *rope_sin_cos,
const int *seq_len_encoder,
const int *seq_len_decoder,
const int *cu_seq_q,
const int *cu_seq_k,
const int max_seq_q,
const int max_seq_k,
const int head_num,
const int kv_head_num,
const int bsz,
const int max_input_length,
cudaStream_t stream) {
static_assert(moba_block_size >= 64, "moba_block_size must be at least 64");
constexpr int kPackSize = 16 / sizeof(input_type);
constexpr int kHeadDim = 128;
constexpr int kThreads = 128;
constexpr int tokens_per_block = kThreads / (kHeadDim / kPackSize);
dim3 grid_dims;
grid_dims.x = bsz;
grid_dims.y = head_num + 2 * kv_head_num;
grid_dims.z = (max_seq_q + tokens_per_block - 1) / tokens_per_block;
if (k_gate_mean != nullptr) {
fused_block_mean_and_rope_kernel<input_type, moba_block_size, kBlockM, kMaxN, tokens_per_block, true>
<<<grid_dims, kThreads, 0, stream>>>(
qkv_input,
qkv_bias,
k_gate_mean,
q_input,
k_input,
v_input,
rope_sin_cos,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
max_input_length);
} else {
fused_block_mean_and_rope_kernel<input_type, moba_block_size, kBlockM, kMaxN, tokens_per_block, false>
<<<grid_dims, kThreads, 0, stream>>>(
qkv_input,
qkv_bias,
k_gate_mean,
q_input,
k_input,
v_input,
rope_sin_cos,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
max_input_length);
}
}
void FusedBlockMeanAndRope(
const paddle::Tensor& qkv_out,
const paddle::Tensor& k_block_means,
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& seq_len_decoder,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::optional<paddle::Tensor>& qkv_bias,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_input_length,
const int max_seq_q,
const int max_seq_k,
const std::string &cache_quant_type_str) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
constexpr int kMobaBlockSize = 128;
constexpr int kMaxN = 1024;
if (k_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
using cute_type = typename cuteType<T>::type;
fused_block_mean_and_rope<cute_type, kMobaBlockSize, kBlockM, kMaxN>(
reinterpret_cast<cute_type *>(const_cast<T*>(qkv_out.data<T>())),
qkv_bias ? reinterpret_cast<cute_type *>(const_cast<T*>(qkv_bias.get().data<T>())) : nullptr,
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
rotary_embs.data<float>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
seq_len_encoder.dims()[0],
max_input_length,
qkv_out.stream());
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
using cute_type = typename cuteType<T>::type;
fused_block_mean_and_rope<cute_type, kMobaBlockSize, kBlockM, kMaxN>(
reinterpret_cast<cute_type *>(const_cast<T*>(qkv_out.data<T>())),
qkv_bias ? reinterpret_cast<cute_type *>(const_cast<T*>(qkv_bias.get().data<T>())) : nullptr,
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
rotary_embs.data<float>(),
seq_len_encoder.data<int>(),
seq_len_decoder.data<int>(),
cu_seq_q.data<int>(),
cu_seq_k.data<int>(),
max_seq_q,
max_seq_k,
head_num,
kv_head_num,
seq_len_encoder.dims()[0],
max_input_length,
qkv_out.stream());
}
}
PD_BUILD_OP(fused_block_mean_and_rope)
.Inputs({
"qkv_out",
"k_block_means",
"q_input",
"k_input",
"v_input",
"rotary_embs",
"seq_len_encoder",
"seq_len_decoder",
"cu_seq_q",
"cu_seq_k",
paddle::Optional("qkv_bias")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_input_length: int",
"max_seq_q: int",
"max_seq_k: int",
"cache_quant_type_str: std::string"})
.Outputs({"q_input_out", "k_input_out", "v_input_out", "k_block_means_out"})
.SetInplaceMap({{"q_input", "q_input_out"},
{"k_input", "k_input_out"},
{"v_input", "v_input_out"},
{"k_block_means", "k_block_means_out"}})
.SetKernelFn(PD_KERNEL(FusedBlockMeanAndRope));

View File

@@ -28,16 +28,6 @@
#define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \
switch (num_experts_per_rank) { \
case 2: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 2; \
__VA_ARGS__ \
break; \
} \
case 6: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 6; \
__VA_ARGS__ \
break; \
} \
case 8: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 8; \
__VA_ARGS__ \
@@ -53,16 +43,6 @@
__VA_ARGS__ \
break; \
} \
case 32: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 32; \
__VA_ARGS__ \
break; \
} \
case 48: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 48; \
__VA_ARGS__ \
break; \
} \
case 64: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
__VA_ARGS__ \
@@ -329,7 +309,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
}
template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int Kthread = 512, int RoundType = 1>
template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int RoundType = 1>
__global__ void permute_x_kernel(const T *src_x,
const int64_t *topk_idx,
const float *topk_weights,
@@ -345,9 +325,9 @@ __global__ void permute_x_kernel(const T *src_x,
int *dst_indices,
int *cumsum_idx_gpu,
int64_t *token_nums_per_expert_cumsum,
int64_t *expert_idx_per_token, // [num_rows, moe_topk]
int64_t *expert_idx_per_token,
float max_bound = 127.0,
float min_bound = -127.0) {
float min_bound = -127.0) { // [num_rows, moe_topk]
const int src_token_idx = blockIdx.x;
const int tid = threadIdx.x;
constexpr int vec_size = sizeof(int4) / sizeof(T);
@@ -390,17 +370,10 @@ __global__ void permute_x_kernel(const T *src_x,
if (up_gate_proj_in_scale) {
for (int i = 0; i < vec_size; i++) {
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
if constexpr (std::is_same<OutT, int8_t>::value) {
// w4aint8
if (RoundType == 0) {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
} else {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(round(quant_value), min_bound, max_bound));
}
if (RoundType == 0) {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
} else {
// w4afp8
float value = ClipFunc<float>(quant_value, min_bound, max_bound);
res_vec[i] = static_cast<OutT>(value);
res_vec[i] = static_cast<OutT>(round(quant_value));
}
}
} else {
@@ -440,10 +413,6 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
auto stream = input.stream();
auto place = input.place();
const int gridx = min(132 * 8, num_rows);
@@ -491,50 +460,6 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
-127.0
);
}
} else if (moe_quant_type == "w4afp8") {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
}
} else {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
@@ -608,7 +533,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
auto permute_input = GetEmptyTensor(
{token_nums_this_rank, hidden_size},
moe_quant_type == "w4a8" ? paddle::DataType::INT8 : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN : input_type,
moe_quant_type == "w4a8" ? paddle::DataType::INT8 : input_type,
place);
auto num_experts_per_rank_tensor = GetEmptyTensor(
{num_experts_per_rank},

View File

@@ -88,7 +88,7 @@ struct nv_type_traits<int8_t> {
constexpr int kLogN = 7; \
__VA_ARGS__ \
} else { \
PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupport!", logN)); \
}
#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...) \
@@ -108,7 +108,7 @@ struct nv_type_traits<int8_t> {
constexpr int VEC_SIZE = 1; \
__VA_ARGS__ \
} else { \
PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", vec_size)); \
PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupport!", vec_size)); \
}
#define DISPATCH_logN(logN, kLogN, ...) \
@@ -605,6 +605,26 @@ void moe_fast_hardamard_kernel(const T *x,
exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
}
if constexpr (kNChunks > 1) {
// T x_vals_transposed[VecSize][kNChunks] = {init_value};
// #pragma unroll
// for (int c = 0; c < kNChunks; ++c) {
// #pragma unroll
// for (int i = 0; i < VecSize; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
// }
// if constexpr (kNChunks == 28) {
// hadamard_mult_thread_chunk_28<VecSize>(x_vals_transposed);
// } else if constexpr (kNChunks == 36) {
// hadamard_mult_thread_chunk_36<VecSize>(x_vals_transposed);
// } else {
// constexpr int kLogNChunks = cilog2(kNChunks);
// static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
// hadamard_mult_thread<kLogNChunks, VecSize>(x_vals_transposed);
// }
// #pragma unroll
// for (int c = 0; c < kNChunks; ++c) {
// #pragma unroll
// for (int i = 0; i < VecSize; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
// }
if constexpr (kNChunks == 28) {
hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
} else if constexpr (kNChunks == 36) {

View File

@@ -72,287 +72,6 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
return u;
}
struct uint8 {
uint4 u;
uint4 v;
};
template<int BYTES> struct BytesToType {};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
};
template <typename InType, typename OutType>
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float scale,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * static_cast<float>(input);
return static_cast<OutType>(ClipFunc<float>(quant_value, min_bound, max_bound));
}
template <typename T, typename OutT, int VecSize, int Kthread>
__global__ void masked_quantize_moe_input_kernel(const T* permuted_inputs,
const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
OutT* out) {
using LoadT = AlignedVector<T, VecSize>;
using LoadOutT = AlignedVector<OutT, VecSize>;
LoadT input_vec;
LoadOutT output_vec;
float scale_factor = -7.0f / 512.0f;
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
const auto expert_id = token_idx / num_max_tokens_per_expert;
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
token_idx += num_iters_to_next_expert * gridDim.x;
continue;
}
int64_t expert_idx = expert_idx_per_token[token_idx];
float quant_scale = quant_scales[expert_idx];
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
thread_row_sum += static_cast<float>(output_vec[i]);
}
*(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T, typename OutT, int VecSize, int Kthread>
__global__ void quantize_moe_input_kernel(const T* permuted_inputs,
const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
OutT* out) {
using LoadT = AlignedVector<T, VecSize>;
using LoadOutT = AlignedVector<OutT, VecSize>;
LoadT input_vec;
LoadOutT output_vec;
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
float scale_factor = -7.0f / 512.0f;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
int64_t expert_idx = expert_idx_per_token[token_idx];
float quant_scale = quant_scales[expert_idx];
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
thread_row_sum += static_cast<float>(output_vec[i]);
}
*(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T, typename OutT>
void quantize_moe_input(
const T* permuted_inputs,
const int64_t* expert_idx_per_token,
const float* quant_scales,
const float quant_max_bound,
const float quant_min_bound,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
OutT* out,
cudaStream_t stream) {
constexpr int VecSize = 16 / sizeof(T);
constexpr int threads_per_block = 128;
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
assert(dim % VecSize == 0);
auto kernel = used_in_ep_low_latency ? masked_quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block> : quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, threads_per_block, 0);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
kernel<<<grid, threads_per_block, 0, stream>>>(
permuted_inputs,
expert_idx_per_token,
quant_scales,
quant_max_bound,
quant_min_bound,
token_num,
dim,
permuted_input_row_sum,
recv_expert_count,
num_max_tokens_per_expert,
out);
}
template <typename T, int VecSize, int Kthread>
__global__ void masked_compute_row_sum_kernel(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert) {
using LoadT = AlignedVector<T, VecSize>;
LoadT input_vec;
float scale_factor = -7.0f / 512.0f;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
const auto expert_id = token_idx / num_max_tokens_per_expert;
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
token_idx += num_iters_to_next_expert * gridDim.x;
continue;
}
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
thread_row_sum += static_cast<float>(input_vec[i]);
}
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T, int VecSize, int Kthread>
__global__ void compute_row_sum_kernel(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert) {
using LoadT = AlignedVector<T, VecSize>;
LoadT input_vec;
float scale_factor = -7.0f / 512.0f;
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
float thread_row_sum = 0.0f;
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
int64_t offset = token_idx * dim + idx * VecSize;
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
thread_row_sum += static_cast<float>(input_vec[i]);
}
}
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
}
}
template <typename T>
void compute_row_sum(
const T* permuted_inputs,
const int64_t token_num,
const int64_t dim,
float* permuted_input_row_sum,
const int64_t* recv_expert_count,
const int num_max_tokens_per_expert,
bool used_in_ep_low_latency,
cudaStream_t stream) {
constexpr int VecSize = 16 / sizeof(T);
constexpr int threads_per_block = 128;
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
assert(dim % VecSize == 0);
auto kernel = used_in_ep_low_latency ? masked_compute_row_sum_kernel<T, VecSize, threads_per_block> : compute_row_sum_kernel<T, VecSize, threads_per_block>;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, threads_per_block, 0);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
dim3 grid;
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
kernel<<<grid, threads_per_block, 0, stream>>>(
permuted_inputs,
token_num,
dim,
permuted_input_row_sum,
recv_expert_count,
num_max_tokens_per_expert);
}
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support
@@ -431,61 +150,8 @@ __launch_bounds__(TPB) __global__
}
}
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
T* output,
const int64_t num_cols,
const int64_t num_rows) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val =
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = T(val);
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void group_moe_top_k(const T* inputs_after_softmax,
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
T* output,
IdxT* indices,
int* source_rows,
@@ -542,7 +208,60 @@ __launch_bounds__(TPB) __global__ void group_moe_top_k(const T* inputs_after_sof
}
}
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
T* output,
const int64_t num_cols,
const int64_t num_rows) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val =
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = T(val);
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
const T* bias,
T* output,
@@ -565,13 +284,153 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;
T weight_sum = static_cast<T>(0);
T* row_outputs = nullptr;
if constexpr (NormWeights){
extern __shared__ char smem[];
row_outputs = reinterpret_cast<T*>(smem);
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
// softmax
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
if (globalIdx >= num_rows) {
return;
}
const int64_t thread_row_offset = globalIdx * num_experts;
const int64_t idx = thread_row_offset+threadIdx.x;
cub::Sum sum;
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0) {
float_max = maxElem;
}
__syncthreads();
float threadDataSub = threadData - float_max;
float threadDataExp = exp(threadDataSub);
const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
T val = T(threadDataExp * normalizing_factor);
// top_k
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
if (threadIdx.x < num_experts) {
cub_kvp inp_kvp;
int expert = threadIdx.x;
inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val;
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
if (prior_winning_expert == expert) {
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp =
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
indices[cur_idx] = result_kvp.key;
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
}
__syncthreads();
}
}
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
const T* bias,
T* output,
IdxT* indices,
int* source_rows,
const int64_t num_experts,
const int64_t k,
const int64_t num_rows) {
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
if (block_row >= num_rows) {
return;
}
const bool should_process_row = true;
const int thread_read_offset = block_row * num_experts;
T weight_sum = static_cast<T>(0);
extern __shared__ char smem[];
T* row_outputs = reinterpret_cast<T*>(smem);
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
@@ -598,32 +457,28 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0) {
const int idx = k * block_row + k_idx;
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
if constexpr (NormWeights){
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
}
else{
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
}
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
}
__syncthreads();
}
if constexpr (NormWeights){
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}
if (threadIdx.x < k) {
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
}
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}
if (threadIdx.x < k) {
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
}
}
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
const T* bias,
T* output,
IdxT* indices,
@@ -677,11 +532,8 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
cub::ArgMax arg_max;
T weight_sum = static_cast<T>(0);
T* row_outputs = nullptr;
if constexpr (NormWeights){
extern __shared__ char smem[];
row_outputs = reinterpret_cast<T*>(smem);
}
extern __shared__ char smem[];
T* row_outputs = reinterpret_cast<T*>(smem);
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
@@ -708,28 +560,22 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
if (threadIdx.x == 0) {
const int cur_idx = k * globalIdx + k_idx;
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
indices[cur_idx] = result_kvp.key;
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
if constexpr (NormWeights) {
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
row_outputs[k_idx] = row_out;
weight_sum += row_out;
}
else {
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
}
}
__syncthreads();
}
if constexpr (NormWeights) {
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}
if (threadIdx.x < k) {
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
}
if (threadIdx.x < WARP_SIZE) {
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
}
if (threadIdx.x < k) {
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
}
}
@@ -851,11 +697,9 @@ template <typename T,
int NUM_EXPERTS,
int WARPS_PER_CTA,
int BYTES_PER_LDG,
bool Norm_Weights = false,
typename IdxT = int>
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
void topk_gating_softmax(const T* input,
const T* bias,
T* output,
const int64_t num_rows,
IdxT* indices,
@@ -911,7 +755,6 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
const int thread_row_in_cta = thread_row - cta_base_row;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows) return;
@@ -927,9 +770,6 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
T weight_sum = static_cast<T>(0);
extern __shared__ T row_output[];
// Determine the pointer type to use to read in the data depending on the
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
// up to 16.
@@ -998,7 +838,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = bias ? row_chunk[ii] * reciprocal_row_sum + bias[first_elt_read_by_thread + ii] : row_chunk[ii] * reciprocal_row_sum;
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find
@@ -1047,20 +887,12 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
}
// Write the max for this k iteration to global memory.
T final_val = bias ? T(max_val) - bias[expert] : T(max_val);
if (thread_group_idx == 0) {
// The lead thread from each sub-group will write out the final results to
// global memory. (This will be a single) thread per row of the
// input/output matrices.
const int idx = k * thread_row + k_idx;
if constexpr (Norm_Weights) {
const int idx_in_cta = k * thread_row_in_cta + k_idx;
row_output[idx_in_cta] = final_val;
weight_sum += final_val;
}
else {
output[idx] = final_val;
}
output[idx] = T(max_val);
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
}
@@ -1083,16 +915,6 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
}
}
}
if constexpr (Norm_Weights) {
#pragma unroll
for (int k_idx = 0; k_idx < k; ++k_idx) {
if (thread_group_idx == 0) {
const int idx = k * thread_row + k_idx;
const int idx_in_cta = k * thread_row_in_cta + k_idx;
output[idx] = row_output[idx_in_cta] / weight_sum;
}
}
}
}
namespace detail {
@@ -1112,9 +934,8 @@ struct TopkConstants {
};
} // namespace detail
template <typename T, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights = false, typename IdxT = int>
template <typename T, int EXPERTS, int WARPS_PER_TB, typename IdxT = int>
void topk_gating_softmax_launcher_helper(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_row,
@@ -1132,10 +953,9 @@ void topk_gating_softmax_launcher_helper(const T* input,
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, Norm_Weights>
<<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>(
input, bias, output, num_rows, indices, source_row, k);
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
<<<num_blocks, block_dim, 0, stream>>>(
input, output, num_rows, indices, source_row, k);
}
template <typename T, typename IdxT = int>
@@ -1166,7 +986,7 @@ static void run(const T* input,
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, gating_correction_bias, output, indices, source_row, num_rows, num_experts, k, stream); \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;
@@ -1195,7 +1015,7 @@ static void run(const T* input,
group_experts,
softmax_num_rows);
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
group_moe_top_k<T, TPB>
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
output,
indices,
@@ -1296,18 +1116,6 @@ __global__ void initialize_moe_routing_kernel(
dest_vec[j] = static_cast<int8_t>(round(quant_value));
}
Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
} else if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
using StoreT = AlignedVector<OutT, VecSize>;
StoreT dest_vec;
const float max_bound = 448.f;
const float min_bound = -448.f;
for (int j = 0; j < VecSize; j++) {
float quant_value = max_bound * scale * static_cast<float>(src_vec[j]);
quant_value = quant_value > max_bound ? max_bound : quant_value;
quant_value = quant_value < min_bound ? min_bound : quant_value;
dest_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(quant_value);
}
Store<phi::dtype::float8_e4m3fn, VecSize>(dest_vec, &dest_row_ptr[tid]);
} else {
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
}

View File

@@ -113,20 +113,11 @@ void MoeDispatchKernel(
permuted_rows_, moe_topk * num_rows, false, stream);
if (w4a8_in_scale) {
if (permute_input->dtype() == paddle::DataType::INT8) {
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
hidden_size, moe_topk, stream);
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
permuted_rows_, expert_idx_per_token->data<int32_t>(),
w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
hidden_size, moe_topk, stream);
}
} else {
initialize_moe_routing_kernelLauncher<data_t>::run(
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
@@ -144,7 +135,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) {
const bool group_moe, const bool topk_only_mode) {
const auto input_type = input.dtype();
auto place = input.place();
int token_rows = 0;
@@ -160,14 +151,8 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
const int num_rows = token_rows;
const int hidden_size = input.dims()[input_dims.size() - 1];
auto permute_input_dtype = input_type;
if (w4a8_in_scale) {
if (moe_quant_type == "w4a8") {
permute_input_dtype = paddle::DataType::INT8;
} else if (moe_quant_type == "w4afp8") {
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
}
}
auto permute_input_dtype =
w4a8_in_scale ? paddle::DataType::INT8 : input_type;
auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
permute_input_dtype, place);
@@ -300,7 +285,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
.Outputs({"permute_input", "tokens_expert_prefix_sum",
"permute_indices_per_token", "topk_weight", "topk_idx",
"expert_idx_per_token"})
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));

View File

@@ -20,7 +20,6 @@
#include "helper.h"
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
#include "w4afp8_gemm/w4afp8_gemm.h"
template <paddle::DataType T>
void MoeFFNKernel(const paddle::Tensor& permute_input,
@@ -34,8 +33,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
bool used_in_ep_low_latency) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
@@ -62,22 +60,19 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
Allocator* allocator = paddle::GetAllocator(place);
Allocator::AllocationPtr workspace;
if (quant_method == "weight_only_int4" || quant_method == "w4a8" || quant_method == "w4afp8") {
if (quant_method == "weight_only_int4" || quant_method == "w4a8") {
inter_dim = inter_dim * 2;
}
if (quant_method == "w4a8" || quant_method == "w4afp8") {
if (quant_method == "w4a8") {
workspace = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * workspace_size);
}
const int64_t inter_size = inter_dim;
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
int num_experts_ = num_experts;
int num_max_tokens_per_expert = 256;
int num_max_tokens_per_expert;
int expanded_active_expert_rows;
paddle::Tensor fc1_out_tensor;
@@ -166,49 +161,13 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
reinterpret_cast<NvType *>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
tune_total_rows,
inter_size,
hidden_size,
reinterpret_cast<char*>(workspace->ptr()),
workspace_size,
num_experts,
stream);
} else if (quant_method == "w4afp8") {
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
typedef typename traits_fp8::DataType DataType_fp8;
typedef typename traits_fp8::data_t data_t_fp8;
Allocator::AllocationPtr ffn1_input_row_sum;
ffn1_input_row_sum = allocator->Allocate(
sizeof(float) * expanded_active_expert_rows);
compute_row_sum(
permute_input.data<data_t_fp8>(),
expanded_active_expert_rows,
hidden_size,
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
stream);
float* row_scale = nullptr;
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8 *>(permute_input.data<data_t_fp8>()),
reinterpret_cast<const DataType_fp8 *>(up_gate_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType *>(fc1_out),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert : permute_input.dims()[0],
num_experts,
inter_size,
hidden_size,
stream);
} else {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm_bias_act(
@@ -235,6 +194,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
}
auto act_out = act_out_tensor.data<data_t>();
if (quant_method == "weight_only_int8") {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm(
@@ -307,73 +267,13 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
reinterpret_cast<NvType *>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
tune_total_rows,
hidden_size,
inter_size / 2,
reinterpret_cast<char*>(workspace->ptr()),
workspace_size,
num_experts,
stream);
} else if (quant_method == "w4afp8") {
data_t *ffn2_shift = nullptr;
data_t *ffn2_smooth = nullptr;
float* row_scale = nullptr;
Allocator::AllocationPtr fp8_act_out;
fp8_act_out = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
Allocator::AllocationPtr ffn2_input_row_sum;
ffn2_input_row_sum = allocator->Allocate(
sizeof(float) * expanded_active_expert_rows);
// note(yuanxiaolan): optimize this
MoeFastHardamardWrapper<data_t, data_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
ffn2_shift, // ffn2_shift->data<T>(),
ffn2_smooth, // ffn2_smooth->data<T>(),
nullptr,
1,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
act_out_tensor.data<data_t>(),
stream
);
quantize_moe_input<data_t, data_t_fp8>(act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
448.0f,
-448.0f,
expanded_active_expert_rows,
inter_size / 2,
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
num_max_tokens_per_expert,
used_in_ep_low_latency,
reinterpret_cast<data_t_fp8 *>(fp8_act_out->ptr()),
stream
);
DisPatchW4AFp8GemmWrapper(
reinterpret_cast<const DataType_fp8 *>(fp8_act_out->ptr()),
reinterpret_cast<const DataType_fp8 *>(down_proj_weight.data<int8_t>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
row_scale,
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
->data<float>(),
reinterpret_cast<NvType*>(ffn_out_data),
used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
used_in_ep_low_latency ? num_max_tokens_per_expert : act_out_tensor.dims()[0],
num_experts,
hidden_size,
inter_size / 2,
stream);
} else {
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm(
@@ -402,12 +302,10 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
const std::string& quant_method, const bool used_in_ep_low_latency) {
const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
permute_input.dtype();
cudaCheckError();
const auto t_type = quant_method == "w4a8" ? up_gate_proj_scale.get().dtype() : permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
switch (t_type) {
@@ -422,9 +320,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums);
ffn_out, used_in_ep_low_latency);
break;
case paddle::DataType::FLOAT16:
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -437,9 +333,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
down_proj_in_scale,
expert_idx_per_token,
quant_method,
ffn_out,
used_in_ep_low_latency,
estimate_total_token_nums);
ffn_out, used_in_ep_low_latency);
break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
@@ -457,8 +351,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
const std::string& quant_method, const bool used_in_ep_low_latency) {
return {MoeExpertFFNFunc(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
@@ -468,9 +361,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
down_proj_scale,
down_proj_in_scale,
expert_idx_per_token,
quant_method,
used_in_ep_low_latency,
estimate_total_token_nums)};
quant_method, used_in_ep_low_latency)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
@@ -484,8 +375,7 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
const std::string& quant_method,
const bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
const bool used_in_ep_low_latency) {
return {permute_input_shape};
}
@@ -498,9 +388,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
const std::string &quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums) {
if (quant_method == "w4a8" || quant_method == "w4afp8") {
const std::string &quant_method, const bool used_in_ep_low_latency) {
if (quant_method == "w4a8") {
return {up_gate_proj_scale_dtype.get()};
} else {
return {permute_input_dtype};
@@ -571,7 +460,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
paddle::Optional("down_proj_in_scale"),
paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));

View File

@@ -51,7 +51,7 @@ void moe_redundant_topk_select_kernel(const T* input,
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;
@@ -102,7 +102,7 @@ void moe_redundant_topk_select_kernel(const T* input,
else {
assert(k<=TPB);
if (apply_norm_weight) {
moe_softmax_top_k_fused<T, TPB, true>
moe_softmax_top_k_normed_fused<T, TPB>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
bias,
output,
@@ -112,7 +112,7 @@ void moe_redundant_topk_select_kernel(const T* input,
k,
num_rows);
} else {
moe_softmax_top_k_fused<T, TPB, false>
moe_softmax_top_k_fused<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
bias,
output,

Some files were not shown because too many files have changed in this diff Show More