Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-11-01 20:32:52 +08:00)

Compare commits: release/2. ... v2.2.1 (400 commits)
| SHA1 |
|---|
| e42dc8c694 |
| 63a03ee152 |
| 9cc2c99539 |
| 31e32b5821 |
| aebe12a58d |
| 8fdb950e9f |
| a460462d2a |
| cb8d87b945 |
| de4feff147 |
| f38b174a75 |
| 6b47773bd6 |
| 0358329946 |
| 01f6934162 |
| 7bdc6f41e5 |
| bba279cf38 |
| 4f460db556 |
| 74d7b9151d |
| 0fa28b1068 |
| cffde70949 |
| 7f9a9b37f3 |
| b41988f4bc |
| 7ccbcc5a62 |
| fbb4e0f8d1 |
| 4e8ba62241 |
| 7e3148ed81 |
| 4f8ff478b3 |
| c4098d56a0 |
| a6b161b007 |
| 7272afe3dc |
| dfc94371ee |
| 35b8362804 |
| d43c2f2577 |
| 14df2c59da |
| 934071578a |
| 36a58f487c |
| d40a1046de |
| fa2369271d |
| 8903f937f9 |
| 1023a67765 |
| d43549953c |
| c7c1627456 |
| d6bf6de5e6 |
| 38e734e183 |
| 051e4a881c |
| b2bb37d7c0 |
| c6e2a37a95 |
| 8d77c1cb51 |
| 41cd3e24c9 |
| 11b18e5ef0 |
| e2c764fd5a |
| 2d975e16b0 |
| 8915c8411d |
| 77c1bd0813 |
| 473cde779f |
| 335d1c8e8f |
| 173e4df982 |
| 199f88ce1e |
| 55ebe855c0 |
| deb7ad205f |
| e9f72df918 |
| 8567ada09e |
| afcde19277 |
| d40d3a5a4f |
| b8d0f1c081 |
| 8550e19008 |
| a0c03510c0 |
| fb1e0d6a87 |
| fbf0e9d2aa |
| 8c0e7d6fe9 |
| b56b015d85 |
| 1432e336d7 |
| 9213a58a06 |
| 87ef0f5d30 |
| abcd2148c0 |
| 05b6591c23 |
| 42402c80e9 |
| 1968c65849 |
| 37cb37b7f2 |
| f975f7de2f |
| 174510180a |
| 5cda326ba2 |
| a6c8f17431 |
| cd09384a14 |
| 0f42771a84 |
| d1d063e4af |
| a86b35ab49 |
| 0cdbc950b5 |
| 2b0a745d57 |
| 1953c7c759 |
| 465065cd19 |
| bed09ae8f8 |
| 753772ace8 |
| 98e03fb4ea |
| fe5d09f9ee |
| b9af95cf1c |
| 9a7c231f2c |
| b21e085f3e |
| 7568b20098 |
| 455205f991 |
| f206474cc7 |
| c4b1f6b0a5 |
| a18afcfdd9 |
| cd252ec673 |
| 3754a9906d |
| ccd52b5596 |
| 65425bf858 |
| c71ee0831c |
| f677c032c0 |
| 48d760539b |
| 45f81b34f0 |
| 1bf4fc7f36 |
| 68f87240da |
| 88297240e7 |
| 17b414c2df |
| b6edd15d55 |
| 2fb2c0f46a |
| 43d5bd62b4 |
| 72094d4d82 |
| 73d60fe64d |
| 0b51b9c35b |
| 4957908275 |
| 02b3644903 |
| 808b548761 |
| 368bbd9dc6 |
| fc635acc47 |
| 17731a8acd |
| 2a73a6df03 |
| e93d4cfcdd |
| 94ded434bd |
| e5015eea05 |
| 73cf6096da |
| 98c217b428 |
| d4fc893fe3 |
| c294fc8139 |
| 108d989d9d |
| b791bea0c5 |
| d37331fc71 |
| ad9b95e6dd |
| e81046fdad |
| 76513f6416 |
| 7afcd4b776 |
| 3d92fb09f7 |
| 479c8b85d3 |
| e37e86b3b8 |
| b28a0343a6 |
| 2974016103 |
| 836345a4dd |
| 11803e0907 |
| c694fa2879 |
| b2afdf4fc6 |
| 1265f6c192 |
| f0140be1e1 |
| e645db348b |
| afb9f327ef |
| 5ad8721506 |
| f8b70bf60c |
| ce9c0917c5 |
| ad319a87cc |
| 85afa72763 |
| 646a0c2fd8 |
| f0a362af18 |
| 82e64b13e1 |
| cbce94a00e |
| 642480f5f6 |
| 2f28f40d90 |
| 3200a80de3 |
| 00898603c8 |
| 9afa236e39 |
| 56e2d7e668 |
| d339df2e90 |
| 52eda7fdb3 |
| 0a0d2959b9 |
| 75db0d1ae2 |
| 70c75798a7 |
| 0bc7d076fc |
| a5b4866ff1 |
| c68c3c4b8b |
| c43a4bec00 |
| 66c5addce4 |
| 2fa173e327 |
| 2ae7ab28d2 |
| c13c904971 |
| 9cab3f47ff |
| 2410adb041 |
| 9205c88da1 |
| 46664985fc |
| 7821534ff5 |
| 137e539456 |
| bdbac0aa3d |
| 77514e3e1e |
| 93e1b63200 |
| e481b7a779 |
| 79f0dbbb55 |
| cb166053ba |
| 36325e9ea7 |
| df7c31012b |
| 27666ee586 |
| 5b66462f0e |
| 7ae41e9daf |
| 76759108c9 |
| cc88671507 |
| 2630260616 |
| 85fbf5455a |
| 3cc182236a |
| c389a4013c |
| e5aa7087db |
| a5692e8b7d |
| 8bea4b1e25 |
| e4f0b755b4 |
| 371fb3f853 |
| 466cbb5a99 |
| b7eee3aec1 |
| c83381d650 |
| 51f68ae593 |
| 985b1265c3 |
| 31f639f10b |
| 30b3f2dc07 |
| bcdfc1d6b9 |
| 33ff0bfe38 |
| e197894977 |
| 9ff2dfb162 |
| 33d369586b |
| 5d131485d8 |
| 3a6058e445 |
| 67298cf4c0 |
| b047681c5d |
| d587fb257f |
| fef447e350 |
| 6735626014 |
| bca8905b40 |
| 8b12c80f90 |
| 3a7a20d191 |
| a053ab889b |
| beec24fd89 |
| c95b3395e9 |
| 32b39620bc |
| 2cf96ddd68 |
| 9c129813f9 |
| 70ee910cd5 |
| ea4a3b479c |
| 5585cf7aa5 |
| 246cd7b3a5 |
| 6fdd83da10 |
| a12d0bc549 |
| 3ee6053e5d |
| e88f5552db |
| 33c0197ebe |
| 154308102e |
| 5703d7aa0f |
| 615930bc05 |
| 6f11171478 |
| 354575b6d1 |
| cc8ee50f27 |
| 4bd6a9fa7d |
| d4e3a20300 |
| fbb6dcb9e4 |
| 562e01c979 |
| cca96ab1e4 |
| 7132fa9ec2 |
| 6c1f3ff897 |
| 5a84324798 |
| f0f00a6025 |
| 09c979f3dd |
| ab60292f89 |
| cacc52bf21 |
| 79d8ae4c38 |
| 1e06b9fa6d |
| 6031f9a5f5 |
| f72db9386c |
| 7b596d0877 |
| 0ea8712018 |
| 2e7831185f |
| 666ab65a51 |
| dd583fb16a |
| d4f610e4cd |
| 396dba0d62 |
| 1ace375fc3 |
| be94bdd0b0 |
| f702a675a1 |
| d1a92e3e17 |
| ce9180241e |
| b4fef2cf29 |
| ed6bff215a |
| 8224b21525 |
| eda83ca672 |
| 2d1a4cacdf |
| 2c0d853067 |
| 8791ad4e61 |
| c575611a5b |
| 90bfa0be9c |
| 5620bd12de |
| 7d0d5a543a |
| ccc7f1beb3 |
| 283da92bfa |
| f5164215be |
| b808c49585 |
| b21272d9ff |
| 183e3863e8 |
| 19fda4e912 |
| 973ddad91e |
| f27e879785 |
| 789dc67ff7 |
| 8bf96217b4 |
| 770b0aa3c5 |
| 9627619235 |
| b23af29d0b |
| c27a3dc43b |
| c56c99837a |
| 9571c458f0 |
| 21caa63794 |
| 42af0b4b64 |
| e0aeac58e1 |
| b88537a456 |
| 71018fb62e |
| 0b77d396ad |
| 79868be220 |
| 46c8491201 |
| 566badb83c |
| eaae4a580d |
| c011cb8b16 |
| 1e4968e810 |
| 31d4fcb425 |
| 22255a65aa |
| a799d14df1 |
| ce1f353c70 |
| d0e9a70380 |
| 71267840f7 |
| b76b17fc1b |
| fac2f64837 |
| fbdd6b0663 |
| 37569cca86 |
| 5f0b30f6d0 |
| 6037dd5d9c |
| 9423c577fe |
| 5885285e57 |
| 55ac449c31 |
| 820798aec5 |
| 0074b423a9 |
| 93a1731891 |
| 09cc4e2802 |
| d9e3f88f9e |
| 9408e667a5 |
| 3a15e0c53e |
| afff4d37ea |
| 20839abccf |
| 91dc87f1c5 |
| 256a82b0b3 |
| 36dc73470d |
| a6e8b780f8 |
| 89397516a8 |
| 841e831575 |
| e0bbd3b6ca |
| 7ce00e597c |
| 4a10e29804 |
| af543b7f0f |
| e24929efa3 |
| b01cfd6007 |
| 55939f7942 |
| 04fc7eb931 |
| 9f1936ae28 |
| 1e9a8e8cef |
| f5c64a074c |
| 14ed75f7d3 |
| 40f7f3e0d8 |
| b8f3c73aac |
| fb7a0689cc |
| c593e1a39c |
| e39159f3bd |
| 88596c0c63 |
| fe540f6caa |
| 72ef5a9c93 |
| 1f8289e106 |
| 3eb9a5df60 |
| 68bc1d12c0 |
| 01d7586661 |
| 2bd8a50649 |
| 0443587a57 |
| 17f51f0c92 |
| 79bbacc152 |
| 3bfb2eca92 |
| c9e6ce1518 |
| 4021d66ea5 |
| 1582814905 |
| 66d3bb89ad |
| 22fe695f1c |
| b71cbb466d |
| 243394044d |
| 0eb32bb9c8 |
| 64d7a3194d |
| bdb83e007d |
| 50db0d7ba9 |
| 94264bbf60 |
| 3a4db15765 |
| c34088b0fd |
| fc5f43c6bc |
| a2f5cc54f8 |
| 1d93565082 |
| e1011e92d9 |
| 8c63237cfa |
| ff6a109b4d |
.github/workflows/_accuracy_test.yml (vendored, new file, 186 lines)
@@ -0,0 +1,186 @@
|
||||
name: Accuracy Test
|
||||
description: "Run Accuracy Tests"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
accuracy_tests:
|
||||
runs-on: [self-hosted, GPU-h20-1Cards]
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
- name: Run FastDeploy Base Tests
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
|
||||
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install ${fastdeploy_wheel_url}
|
||||
python -m pip install pytest
|
||||
|
||||
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
|
||||
chmod +x ./llm-deploy-linux-amd64
|
||||
./llm-deploy-linux-amd64 -python python3.10 \
|
||||
-model_name ERNIE-4.5-0.3B-Paddle \
|
||||
-model_path /MODELDATA \
|
||||
--skip install
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
pushd tests/ce/deploy
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
|
||||
|
||||
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
|
||||
popd
|
||||
|
||||
pushd tests/ce/accuracy_cases
|
||||
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
|
||||
export TEMPLATE=TOKEN_LOGPROB
|
||||
export MODEL_SIZE=0.3B
|
||||
TEST_EXIT_CODE=0
|
||||
python gsm8k.py || TEST_EXIT_CODE=1
|
||||
popd
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
|
||||
'
|
||||
if [ -f ./FastDeploy/exit_code.env ]; then
|
||||
source ./FastDeploy/exit_code.env
|
||||
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
|
||||
exit ${TEST_EXIT_CODE}
|
||||
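All of the test workflows in this set derive their GPU list and service ports from the self-hosted runner's name. A condensed, standalone sketch of that scheme follows; the example runner name is an assumption, while the port bases are the ones used in the workflow above.

```bash
#!/usr/bin/env bash
# Sketch of the device/port derivation shared by the test jobs above.
# Assumes the runner name ends with the card ID(s), e.g. "GPU-h20-1Cards-03".
runner_name="${1:-GPU-h20-1Cards-0}"

CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')   # "03" for a runner named "...-03"
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)           # "0,3": one GPU index per character
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)               # first card selects the port block

# Each card index owns a disjoint block of 100 ports.
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))

echo "DEVICES=${DEVICES} FLASK_PORT=${FLASK_PORT} FD_API_PORT=${FD_API_PORT}"
```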
.github/workflows/_base_test.yml (vendored, new file, 229 lines)
@@ -0,0 +1,229 @@
|
||||
name: Base Test
|
||||
description: "Run Base Tests"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
base_tests:
|
||||
runs-on: [self-hosted, GPU-h20-1Cards]
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
- name: Run FastDeploy Base Tests
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
|
||||
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install ${fastdeploy_wheel_url}
|
||||
python -m pip install pytest
|
||||
|
||||
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
|
||||
chmod +x ./llm-deploy-linux-amd64
|
||||
./llm-deploy-linux-amd64 -python python3.10 \
|
||||
-model_name ERNIE-4.5-0.3B-Paddle \
|
||||
-model_path /MODELDATA \
|
||||
--skip install
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
pushd tests/ce/deploy
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
|
||||
|
||||
check_service() {
|
||||
local timeout=${1:-90}
|
||||
local url="http://localhost:${FLASK_PORT}/wait_for_infer?timeout=${timeout}"
|
||||
local resp
|
||||
|
||||
resp=$(curl -s -X POST "$url")
|
||||
|
||||
if echo "$resp" | grep -q "服务启动超时"; then
|
||||
exit 8
|
||||
fi
|
||||
}
|
||||
|
||||
check_service 90
|
||||
popd
|
||||
|
||||
pushd tests/ce/server
|
||||
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
|
||||
export TEMPLATE=TOKEN_LOGPROB
|
||||
TEST_EXIT_CODE=0
|
||||
python -m pytest -sv test_base_chat.py test_compare_top_logprobs.py test_logprobs.py test_params_boundary.py test_seed_usage.py test_stream.py test_evil_cases.py test_completions.py test_return_token_ids.py || TEST_EXIT_CODE=1
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--early-stop-config\": \"{\\\"enable_early_stop\\\":true, \\\"window_size\\\":6, \\\"threshold\\\":0.93}\"}"
|
||||
check_service 90
|
||||
python -m pytest -sv test_repetition_early_stop.py || TEST_EXIT_CODE=1
|
||||
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }"
|
||||
check_service 90
|
||||
python -m pytest -sv test_max_concurrency.py || TEST_EXIT_CODE=1
|
||||
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }"
|
||||
check_service 90
|
||||
python -m pytest -sv test_max_waiting_time.py || TEST_EXIT_CODE=1
|
||||
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_mtp.yaml\", \"--enable-logprob\": \"False\"}"
|
||||
check_service 180
|
||||
export TEMPLATE=TOKEN_NORMAL
|
||||
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
|
||||
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_sot.yaml\", \"--enable-logprob\": \"False\"}"
|
||||
check_service 360
|
||||
export TEMPLATE=TOKEN_NORMAL
|
||||
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
|
||||
|
||||
popd
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
|
||||
'
|
||||
if [ -f ./FastDeploy/exit_code.env ]; then
|
||||
source ./FastDeploy/exit_code.env
|
||||
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
|
||||
exit ${TEST_EXIT_CODE}
|
||||
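The base test gates every model switch on a `check_service` helper that polls the deploy controller's `/wait_for_infer` endpoint and bails out when the response carries the startup-timeout message (服务启动超时, "service startup timed out"). A minimal standalone sketch of the same gate, assuming the controller listens on `FLASK_PORT` as derived above:

```bash
# Readiness gate between model switches: block on the deploy controller until
# the model is serving, or abort the whole run if it never comes up.
check_service() {
    local timeout=${1:-90}
    local url="http://localhost:${FLASK_PORT}/wait_for_infer?timeout=${timeout}"
    local resp

    resp=$(curl -s -X POST "$url")

    # The controller answers with this phrase when startup exceeded the timeout.
    if echo "$resp" | grep -q "服务启动超时"; then
        echo "service did not become ready within ${timeout}s" >&2
        exit 8
    fi
}

check_service 90    # small 0.3B model
check_service 180   # larger 21B MTP config, as used after the later /switch calls
```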
.github/workflows/_build_linux.yml (vendored, 35 changed lines)
@@ -22,12 +22,22 @@ on:
|
||||
description: "Enable nightly build mode (e.g. add date suffix to version)"
|
||||
required: false
|
||||
type: string
|
||||
default: "ON"
|
||||
default: "OFF"
|
||||
FD_VERSION:
|
||||
description: "FastDeploy Package Version"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
PADDLEVERSION:
|
||||
description: "Paddle Version Build Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
PADDLE_WHL_URL:
|
||||
description: "Paddle Wheel Package URL"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
UPLOAD:
|
||||
description: "Upload Package"
|
||||
required: false
|
||||
@@ -45,6 +55,7 @@ on:
|
||||
jobs:
|
||||
fd-build:
|
||||
runs-on: [self-hosted, GPU-Build]
|
||||
timeout-minutes: 240
|
||||
outputs:
|
||||
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
|
||||
steps:
|
||||
@@ -85,6 +96,10 @@ jobs:
|
||||
compile_arch: ${{ inputs.COMPILE_ARCH }}
|
||||
fd_version: ${{ inputs.FD_VERSION }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
BRANCH_REF: ${{ github.ref_name }}
|
||||
PADDLEVERSION: ${{ inputs.PADDLEVERSION }}
|
||||
PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }}
|
||||
WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }}
|
||||
run: |
|
||||
set -x
|
||||
runner_name="${{ runner.name }}"
|
||||
@@ -109,6 +124,9 @@ jobs:
|
||||
-e "COMPILE_ARCH=${compile_arch}" \
|
||||
-e "FD_VERSION=${fd_version}" \
|
||||
-e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \
|
||||
-e "PADDLEVERSION=${PADDLEVERSION}" \
|
||||
-e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
|
||||
-e "BRANCH_REF=${BRANCH_REF}" \
|
||||
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
|
||||
if [[ -n "${FD_VERSION}" ]]; then
|
||||
export FASTDEPLOY_VERSION=${FD_VERSION}
|
||||
@@ -116,6 +134,7 @@ jobs:
|
||||
fi
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
chown -R $(whoami) /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
|
||||
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
|
||||
@@ -124,14 +143,20 @@ jobs:
|
||||
echo "Date Only: $DATE_ONLY"
|
||||
export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}"
|
||||
fi
|
||||
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
|
||||
pip config set install.trusted-host pip.baidu.com
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
# Use a different PaddlePaddle package depending on the branch or tag
|
||||
if [[ "${PADDLE_WHL_URL}" != "" ]];then
|
||||
python -m pip install ${PADDLE_WHL_URL}
|
||||
elif [[ "${PADDLEVERSION}" != "" ]];then
|
||||
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
else
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
fi
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install wheel
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
# Build with RDMA support
|
||||
export ENABLE_FD_RDMA=1
|
||||
bash build.sh 1 python false [${COMPILE_ARCH}]
|
||||
|
||||
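The build job resolves its PaddlePaddle dependency with a three-level precedence: an explicit wheel URL wins, then an explicit version, then a pinned default. A condensed sketch of just that selection logic; the 3.2.0 default and the cu126 stable index come from the diff above.

```bash
# PaddlePaddle selection precedence used by the build job:
# PADDLE_WHL_URL > PADDLEVERSION > pinned default (3.2.0 from the cu126 stable index).
if [[ -n "${PADDLE_WHL_URL}" ]]; then
    python -m pip install "${PADDLE_WHL_URL}"
elif [[ -n "${PADDLEVERSION}" ]]; then
    python -m pip install "paddlepaddle-gpu==${PADDLEVERSION}" \
        -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
else
    python -m pip install paddlepaddle-gpu==3.2.0 \
        -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
fi
```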
.github/workflows/_clone_linux.yml (vendored, 2 changed lines)
@@ -68,7 +68,7 @@ jobs:
  branch_name=${{ github.ref_name }}
  target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}
  fi
- wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
+ wget -O bos_tools.py -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
  push_file=$(realpath bos_tools.py)
  python -m pip install bce-python-sdk==0.9.29
  ls
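The single changed line above pins the download to a fixed filename, so the later `realpath bos_tools.py` always resolves to the freshly downloaded file. A minimal sketch of the pattern:

```bash
# Download the helper under a fixed name, then resolve its absolute path for reuse.
wget -O bos_tools.py -q --no-proxy --no-check-certificate \
    "https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py"
push_file=$(realpath bos_tools.py)
```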
.github/workflows/_logprob_test_linux.yml (vendored, 72 changed lines)
@@ -62,18 +62,24 @@ jobs:
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
if [[ "$last_char" =~ [0-7] ]]; then
|
||||
DEVICES="$last_char"
|
||||
else
|
||||
DEVICES="0"
|
||||
fi
|
||||
|
||||
FLASK_PORT=$((9160 + DEVICES * 100))
|
||||
FD_API_PORT=$((9180 + DEVICES * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + DEVICES * 100))
|
||||
FD_METRICS_PORT=$((9170 + DEVICES * 100))
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
@@ -85,28 +91,52 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
docker run --ipc=host --pid=host --net=host \
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c '
|
||||
# python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
python -m pip install paddlepaddle-gpu==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
|
||||
pip config set install.trusted-host pip.baidu.com
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
python -m pip install ${fastdeploy_wheel_url}
|
||||
|
||||
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
|
||||
@@ -124,6 +154,10 @@ jobs:
|
||||
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
|
||||
|
||||
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
|
||||
curl -s -o /dev/null -w "%{http_code}" -m 2 "http://0.0.0.0:${FD_API_PORT}/health"
|
||||
curl -X POST "http://0.0.0.0:${FD_API_PORT}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"messages\": [{\"role\": \"user\", \"content\": \"1+1=?\"}], \"logprobs\": true}"
|
||||
set +e
|
||||
rm -rf ./baseline_output
|
||||
cp -r baseline/ERNIE-4.5-0.3B-Paddle ./baseline_output
|
||||
|
||||
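Once the deploy controller reports the model as up, the logprob job smoke-tests the API server before comparing against the baseline output: a quick `/health` probe, then one chat completion with `logprobs` enabled. A minimal curl sketch, with ports derived as above:

```bash
# Readiness probe: print the HTTP status of the API server's /health endpoint.
curl -s -o /dev/null -w "%{http_code}\n" -m 2 "http://0.0.0.0:${FD_API_PORT}/health"

# One logprob-enabled completion as a smoke test before the baseline comparison.
curl -X POST "http://0.0.0.0:${FD_API_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d '{"messages": [{"role": "user", "content": "1+1=?"}], "logprobs": true}'
```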
.github/workflows/_pre_ce_test.yml (vendored, new file, 148 lines)
@@ -0,0 +1,148 @@
|
||||
name: Pre-CE-Test
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
run_ce_cases:
|
||||
runs-on: [self-hosted, PRE_CE_RUN_2Card]
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
|
||||
-e "MODEL_PATH=/ModelData" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-e "fd_wheel_url=${fd_wheel_url}" \
|
||||
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install ${fd_wheel_url}
|
||||
bash scripts/run_pre_ce.sh
|
||||
'
|
||||
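Every job in this family runs the same pre-flight cleanup before launching its container: free the ports it is about to claim, then remove any stale container still carrying the runner's name. A reusable sketch of that cleanup, assuming `lsof` is available on the runner and that `PORTS` and `runner_name` are set as above:

```bash
# Pre-flight cleanup shared by the test jobs: free the ports, drop stale containers.
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"

for port in "${PORTS[@]}"; do
    PIDS=$(lsof -t -i :"$port" || true)
    if [ -n "$PIDS" ]; then
        echo "Port $port is occupied by PID(s): $PIDS" | tee -a "$LOG_FILE"
        echo "$PIDS" | xargs -r kill -9
        echo "Port $port cleared" | tee -a "$LOG_FILE"
    else
        echo "Port $port is free" | tee -a "$LOG_FILE"
    fi
done

# A cancelled previous run can leave its container behind; containers are named
# after the runner, so a single name check is enough to find it.
if [ "$(docker ps -a -q -f name="${runner_name}")" ]; then
    docker rm -f "${runner_name}" || true
fi
```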
.github/workflows/_stable_test.yml (vendored, new file, 170 lines)
@@ -0,0 +1,170 @@
|
||||
name: Stable Test
|
||||
description: "Run Stable Tests"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
stable_tests:
|
||||
runs-on: [self-hosted, GPU-h1z1-2Cards]
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
- name: Run FastDeploy Stable Tests
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42038 + DEVICE_PORT * 100))
|
||||
FD_INFERENCE_MSG_QUEUE_ID=$(( 42048 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
|
||||
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-e "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install ${fastdeploy_wheel_url}
|
||||
python -m pip install pytest
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
TEST_EXIT_CODE=0
|
||||
pushd tests/ce/stable_cases
|
||||
bash launch_model.sh /MODELDATA
|
||||
bash run.sh || TEST_EXIT_CODE=1
|
||||
popd
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
|
||||
'
|
||||
if [ -f ./FastDeploy/exit_code.env ]; then
|
||||
source ./FastDeploy/exit_code.env
|
||||
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
|
||||
exit ${TEST_EXIT_CODE}
|
||||
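Because the suites run inside a disposable container, the result has to be handed back to the runner: the container appends `TEST_EXIT_CODE` to an env file on the shared workspace, and the step sources it, republishes it through `$GITHUB_ENV`, and exits with the same code. A condensed sketch of that hand-off, as used by the accuracy, base, and stable jobs above:

```bash
# Inside the container: record the aggregated result on the shared volume.
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env

# Back on the runner: pick the result up, expose it to later steps, and fail
# this step with the same code.
if [ -f ./FastDeploy/exit_code.env ]; then
    source ./FastDeploy/exit_code.env
    cat ./FastDeploy/exit_code.env >> "$GITHUB_ENV"
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit "${TEST_EXIT_CODE}"
```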
.github/workflows/_unit_test_coverage.yml (vendored, 259 changed lines)
@@ -1,4 +1,4 @@
|
||||
name: Run FastDeploy Unit Tests and Coverage
|
||||
name: Coverage Check
|
||||
description: "Run FastDeploy Unit Tests and Coverage"
|
||||
|
||||
on:
|
||||
@@ -22,13 +22,32 @@ on:
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
secrets:
|
||||
github-token:
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
check_cov_skip:
|
||||
uses: ./.github/workflows/check-bypass.yml
|
||||
secrets:
|
||||
github-token: ${{ secrets.github-token }}
|
||||
with:
|
||||
workflow-name: coverage
|
||||
|
||||
run_tests_with_coverage:
|
||||
runs-on: [self-hosted, GPU-h1z1-4Cards]
|
||||
runs-on: [self-hosted, GPU-h1z1-2Cards]
|
||||
timeout-minutes: 60
|
||||
needs: check_cov_skip
|
||||
if: needs.check_cov_skip.outputs.can-skip != 'true'
|
||||
outputs:
|
||||
diff_cov_file_url: ${{ steps.cov_upload.outputs.diff_cov_file_url }}
|
||||
unittest_failed_url: ${{ steps.unittest_failed.outputs.unittest_failed_url }}
|
||||
unittest_failed_url: ${{ steps.cov_upload.outputs.unittest_failed_url }}
|
||||
diff_cov_result_json_url: ${{ steps.cov_upload.outputs.diff_cov_result_json_url }}
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
@@ -66,58 +85,128 @@ jobs:
|
||||
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
BASE_REF: ${{ github.event.pull_request.base.ref }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
IS_PR: ${{ github.event_name == 'pull_request' }}
|
||||
run: |
|
||||
set -x
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
if [[ "$IS_PR" == "true" ]]; then
|
||||
echo "Running on PR"
|
||||
else
|
||||
echo "Not a PR"
|
||||
fi
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host \
|
||||
--cap-add=SYS_PTRACE --privileged --shm-size=64G \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
-e "fd_wheel_url=${fd_wheel_url}" \
|
||||
-e "BASE_REF=${BASE_REF}" \
|
||||
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
# python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
python -m pip install paddlepaddle-gpu==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
|
||||
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
|
||||
pip config set install.trusted-host pip.baidu.com
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
python -m pip install coverage
|
||||
python -m pip install diff-cover
|
||||
python -m pip install ${fd_wheel_url}
|
||||
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
|
||||
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
|
||||
TEST_EXIT_CODE=0
|
||||
bash scripts/coverage_run.sh || TEST_EXIT_CODE=8
|
||||
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env
|
||||
coverage combine coveragedata/
|
||||
coverage xml -o python_coverage_all.xml
|
||||
COVERAGE_EXIT_CODE=0
|
||||
diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=90 || COVERAGE_EXIT_CODE=9
|
||||
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --net=host \
|
||||
--name ${runner_name} \
|
||||
--cap-add=SYS_PTRACE --shm-size=64G \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
|
||||
-e "MODEL_PATH=/ModelData" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
-e "fd_wheel_url=${fd_wheel_url}" \
|
||||
-e "BASE_REF=${BASE_REF}" \
|
||||
-e "IS_PR=${IS_PR}" \
|
||||
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install coverage
|
||||
python -m pip install diff-cover
|
||||
python -m pip install pytest-cov
|
||||
python -m pip install jsonschema aistudio_sdk==0.3.5
|
||||
python -m pip install ${fd_wheel_url}
|
||||
rm -rf fastdeploy
|
||||
# install a second copy into the workspace so coverage can also trace subprocesses
|
||||
python -m pip install ${fd_wheel_url} --no-deps --target=/workspace/FastDeploy
|
||||
export PYTHONPATH=/workspace/FastDeploy/
|
||||
if [ -d "tests/plugins" ]; then
|
||||
cd tests/plugins
|
||||
python setup.py install
|
||||
cd ../..
|
||||
else
|
||||
echo "Warning: tests/plugins directory not found, skipping setup.py install"
|
||||
fi
|
||||
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
|
||||
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
|
||||
TEST_EXIT_CODE=0
|
||||
bash scripts/coverage_run.sh || TEST_EXIT_CODE=8
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env
|
||||
coverage combine coveragedata/ || echo "No data to combine"
|
||||
coverage report
|
||||
coverage xml -o python_coverage_all.xml
|
||||
COVERAGE_EXIT_CODE=0
|
||||
if [[ "$IS_PR" == "true" ]]; then
|
||||
diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9
|
||||
python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml
|
||||
'
|
||||
if [ -f FastDeploy/exit_code.env ]; then
|
||||
cat FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
else
|
||||
echo "Not a PR, skipping diff-cover"
|
||||
fi
|
||||
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
|
||||
'
|
||||
if [ -f FastDeploy/exit_code.env ]; then
|
||||
cat FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Upload unit test result and diff coverage to bos
|
||||
id: cov_upload
|
||||
shell: bash
|
||||
@@ -125,42 +214,97 @@ jobs:
|
||||
cd FastDeploy
|
||||
commit_id=${{ github.event.pull_request.head.sha }}
|
||||
pr_num=${{ github.event.pull_request.number }}
|
||||
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}/CoverageData
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py -O bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
diff_cov_file="diff_coverage.xml"
|
||||
if [ -f ${diff_cov_file} ];then
|
||||
python ${push_file} ${diff_cov_file} ${target_path}
|
||||
python ${push_file} ${diff_cov_file} ${target_path}/CoverageData
|
||||
target_path_stripped="${target_path#paddle-github-action/}"
|
||||
DIFF_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${diff_cov_file}
|
||||
DIFF_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_file}
|
||||
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_OUTPUT
|
||||
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Determine Unit Succ and whether the coverage rate reaches 90%
|
||||
diff_cov_result_json="diff_coverage.json"
|
||||
if [ -f ${diff_cov_result_json} ];then
|
||||
python ${push_file} ${diff_cov_result_json} ${target_path}/CoverageData
|
||||
target_path_stripped="${target_path#paddle-github-action/}"
|
||||
DIFF_COV_JSON_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_result_json}
|
||||
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT
|
||||
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV
|
||||
fi
|
||||
unittest_result="failed_tests.log"
|
||||
if [ -s ${unittest_result} ];then
|
||||
python ${push_file} ${unittest_result} ${target_path}/UnitTestResult
|
||||
target_path_stripped="${target_path#paddle-github-action/}"
|
||||
UNIT_TEST_RESULT_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/UnitTestResult/${unittest_result}
|
||||
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
|
||||
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Check Unit Test Success
|
||||
shell: bash
|
||||
run: |
|
||||
cd FastDeploy
|
||||
if [ "$TEST_EXIT_CODE" -eq 8 ]; then
|
||||
filename=$(basename "$unittest_failed_url")
|
||||
if [ -z "${unittest_failed_url}" ]; then
|
||||
echo "No diff unit failed file URL provided."
|
||||
else
|
||||
rm -rf "${filename}"
|
||||
wget -O ${filename} ${unittest_failed_url} || echo "Download unittest file failed, but continuing..."
|
||||
fi
|
||||
echo "Unit tests failed (exit code 8)"
|
||||
if [ -f "${filename}" ];then
|
||||
echo "Failed test cases:"
|
||||
cat "${filename}"
|
||||
fi
|
||||
exit "$TEST_EXIT_CODE"
|
||||
fi
|
||||
echo "All tests passed"
|
||||
|
||||
- name: Verify Code Coverage Threshold (80%)
|
||||
if: ${{ github.event_name == 'pull_request' }}
|
||||
shell: bash
|
||||
run: |
|
||||
cd FastDeploy
|
||||
if [ "$COVERAGE_EXIT_CODE" -eq 9 ]; then
|
||||
echo "Coverage generation failed (exit code 9)"
|
||||
filename=$(basename "$diff_cov_result_json_url")
|
||||
if [ -z "${diff_cov_result_json_url}" ]; then
|
||||
echo "No diff cov result file URL provided."
|
||||
else
|
||||
rm -rf "${filename}"
|
||||
wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..."
|
||||
fi
|
||||
if [ -f "${filename}" ];then
|
||||
echo "Failed test cases:"
|
||||
if command -v jq >/dev/null 2>&1; then
|
||||
jq . "${filename}"
|
||||
else
|
||||
cat "${filename}"
|
||||
fi
|
||||
fi
|
||||
exit "$COVERAGE_EXIT_CODE"
|
||||
fi
|
||||
echo "All tests and coverage passed"
|
||||
echo "coverage passed"
|
||||
exit 0
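The PR gate above hinges on two exit codes: TEST_EXIT_CODE=8 marks unit-test failures and COVERAGE_EXIT_CODE=9 marks a diff-coverage shortfall. The same check can be reproduced locally with the commands already used inside the container step; this is only a sketch and assumes the scripts/.coveragerc and coveragedata/ layout from that step, with develop standing in for the PR base branch:

    # limit the report to the lines changed by the PR
    git diff origin/develop..HEAD --unified=0 > diff.txt
    # merge per-process coverage data and export Cobertura XML
    export COVERAGE_RCFILE=scripts/.coveragerc
    coverage combine coveragedata/ || echo "No data to combine"
    coverage xml -o python_coverage_all.xml
    # fail when the changed lines fall below 80% coverage
    diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json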
|
||||
|
||||
diff_coverage_report:
|
||||
needs: run_tests_with_coverage
|
||||
if: always()
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
steps:
|
||||
- name: coverage diff file download
|
||||
shell: bash
|
||||
env:
|
||||
diff_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url }}
|
||||
run: |
|
||||
wget ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
if [ -z "${diff_cov_file_url}" ]; then
|
||||
echo "No diff coverage file URL provided."
|
||||
exit 0
|
||||
@@ -170,6 +314,9 @@ jobs:
|
||||
if: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url != null && needs.run_tests_with_coverage.outputs.diff_cov_file_url != '' }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./diff_coverage.xml
|
||||
files: ./FastDeploy/diff_coverage.xml
|
||||
name: python diff coverage
|
||||
verbose: true
|
||||
disable_search: true
|
||||
commit_parent: false
|
||||
flags: diff
|
||||
|
||||
4
.github/workflows/approve.yml
vendored
@@ -6,6 +6,9 @@ on:
|
||||
- develop
|
||||
- 'release/*'
|
||||
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
jobs:
|
||||
Approval:
|
||||
name: Approval
|
||||
@@ -33,7 +36,6 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Run approval check script
|
||||
run: |
|
||||
|
||||
248
.github/workflows/ce_job.yml
vendored
Normal file
@@ -0,0 +1,248 @@
|
||||
name: CE Compile Job
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ce_job_pre_check:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
COMPILE_BRANCH: ${{ vars.COMPILE_BRANCH }}
|
||||
CE_COMPILE_SELECTION: ${{ vars.CE_COMPILE_SELECTION }}
|
||||
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
|
||||
outputs:
|
||||
branch_match: ${{ steps.set_output.outputs.branch_match }}
|
||||
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
|
||||
sm8689_match: ${{ steps.set_output.outputs.sm8689_match }}
|
||||
sm8090_match: ${{ steps.set_output.outputs.sm8090_match }}
|
||||
|
||||
steps:
|
||||
- name: Set Version
|
||||
id: set_output
|
||||
env:
|
||||
COMPILE_BRANCH: ${{ env.COMPILE_BRANCH }}
|
||||
CE_COMPILE_SELECTION: ${{ env.CE_COMPILE_SELECTION }}
|
||||
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ env.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
|
||||
GITHUB_REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
# Select which branches trigger the compile job (done)
# Select which compile tasks (sm8090 or sm8689) run for a given branch
# Specify which Paddle package a branch compiles against; by default the latest nightly is used
|
||||
|
||||
IFS=',' read -ra BRANCHES <<< "$COMPILE_BRANCH"
|
||||
MATCH=false
|
||||
for b in "${BRANCHES[@]}"; do
|
||||
if [[ "$b" == "${GITHUB_REF_NAME}" ]]; then
|
||||
MATCH=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
echo "branch_match=$MATCH" >> $GITHUB_OUTPUT
|
||||
|
||||
# Use the mapping in CE_COMPILE_SELECTION to decide whether this branch compiles sm8090 or sm8689
|
||||
for pair in $(echo "$CE_COMPILE_SELECTION" | tr ';' ' '); do
|
||||
branch=$(echo "$pair" | cut -d',' -f1)
|
||||
compile_task_list=$(echo "$pair" | cut -d',' -f2)
|
||||
|
||||
if [[ "$branch" == "$GITHUB_REF_NAME" ]]; then
|
||||
|
||||
# Check whether the task list contains sm8090 or sm8689
|
||||
if [[ "$compile_task_list" == *"sm8090"* ]]; then
|
||||
echo "sm8090_match=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
if [[ "$compile_task_list" == *"sm8689"* ]]; then
|
||||
echo "sm8689_match=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# Use the mapping in COMPILE_USE_PADDLE_WHL_URL_MAPPINGS to decide whether to install a pinned Paddle version or install directly from a wheel URL
|
||||
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
|
||||
branch=$(echo "$pair" | cut -d',' -f1)
|
||||
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
|
||||
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
|
||||
FOUND_PADDLE_URL="$paddle_whl_url"
|
||||
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
|
||||
break
|
||||
fi
|
||||
done
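CE_COMPILE_SELECTION and COMPILE_USE_PADDLE_WHL_URL_MAPPINGS are repository variables whose real values are not part of this diff; the parsing above only assumes a semicolon-separated list of branch,value pairs. A hypothetical illustration of that format and of the tr/cut split:

    # hypothetical values -- the real ones live in GitHub repository variables
    CE_COMPILE_SELECTION='develop,sm8090;release/2.2,sm8689'
    COMPILE_USE_PADDLE_WHL_URL_MAPPINGS='develop,https://example.com/paddlepaddle_gpu-nightly.whl'
    for pair in $(echo "$CE_COMPILE_SELECTION" | tr ';' ' '); do
        echo "branch=$(echo "$pair" | cut -d',' -f1) tasks=$(echo "$pair" | cut -d',' -f2)"
    done
    # prints: branch=develop tasks=sm8090
    #         branch=release/2.2 tasks=sm8689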
|
||||
|
||||
print_ce_job_pre_check_outputs:
|
||||
runs-on: ubuntu-latest
|
||||
needs: ce_job_pre_check
|
||||
steps:
|
||||
- name: Print outputs as JSON
|
||||
run: |
|
||||
echo '${{ toJSON(needs.ce_job_pre_check.outputs) }}'
|
||||
|
||||
|
||||
clone:
|
||||
environment: CodeSync
|
||||
name: FD-Clone-Linux
|
||||
runs-on: ubuntu-latest
|
||||
needs: ce_job_pre_check
|
||||
if: ${{ needs.ce_job_pre_check.outputs.branch_match == 'true' }}
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request'
|
||||
&& github.event.pull_request.base.ref
|
||||
|| github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Python Setup
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
||||
|
||||
resultshow:
|
||||
name: Show Code Archive Output
|
||||
needs: clone
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Print wheel path
|
||||
run: |
|
||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
||||
|
||||
build_sm8090:
|
||||
name: BUILD_SM8090
|
||||
needs: [clone, ce_job_pre_check]
|
||||
if: ${{ needs.ce_job_pre_check.outputs.sm8090_match == 'true' }}
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "80,90"
|
||||
WITH_NIGHTLY_BUILD: OFF
|
||||
FD_VERSION: 0.0.0
|
||||
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
build_sm8689:
|
||||
name: BUILD_SM8689
|
||||
needs: [clone, ce_job_pre_check]
|
||||
if: ${{ needs.ce_job_pre_check.outputs.sm8689_match == 'true' }}
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "86,89"
|
||||
WITH_NIGHTLY_BUILD: OFF
|
||||
FD_VERSION: 0.0.0
|
||||
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
ce_upload_sm8090:
|
||||
environment: CodeSync
|
||||
name: CE_UPLOAD
|
||||
needs: build_sm8090
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
COMPILE_ARCH: "80,90"
|
||||
steps:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
run: |
|
||||
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
|
||||
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
|
||||
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
|
||||
ce_upload_sm8689:
|
||||
environment: CodeSync
|
||||
name: CE_UPLOAD
|
||||
needs: build_sm8689
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
|
||||
COMPILE_ARCH: "86,89"
|
||||
steps:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
run: |
|
||||
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
|
||||
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
|
||||
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
51
.github/workflows/check-bypass.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
workflow-name:
|
||||
required: true
|
||||
type: string
|
||||
secrets:
|
||||
github-token:
|
||||
required: true
|
||||
outputs:
|
||||
can-skip:
|
||||
description: "Whether the workflow can be skipped."
|
||||
value: ${{ jobs.check-bypass.outputs.can-skip }}
|
||||
|
||||
jobs:
|
||||
check-bypass:
|
||||
name: Check bypass
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
env:
|
||||
CI_TEAM_MEMBERS: '["yuanlehome","YuanRisheng","Jiang-Jia-Jun","DDDivano","XieYunshen"]'
|
||||
outputs:
|
||||
can-skip: ${{ steps.check-bypass.outputs.can-skip }}
|
||||
steps:
|
||||
- name: Cleanup
|
||||
run: |
|
||||
rm -rf * .[^.]*
|
||||
|
||||
- id: check-bypass
|
||||
name: Check Bypass
|
||||
uses: PFCCLab/ci-bypass@v1
|
||||
with:
|
||||
github-token: ${{ secrets.github-token }}
|
||||
non-pull-request-event-strategy: 'never-skipped'
|
||||
type: 'composite'
|
||||
composite-rule: |
|
||||
{
|
||||
"any": [
|
||||
{
|
||||
"type": "labeled",
|
||||
"label": ["skip-ci: ${{ inputs.workflow-name }}", "skip-ci: all"],
|
||||
"username": ${{ env.CI_TEAM_MEMBERS }}
|
||||
},
|
||||
{
|
||||
"type": "commented",
|
||||
"comment-pattern": [".*/skip-ci ${{ inputs.workflow-name }}.*", ".*/skip-ci all.*"],
|
||||
"username": ${{ env.CI_TEAM_MEMBERS }}
|
||||
}
|
||||
]
|
||||
}
|
||||
89
.github/workflows/ci.yml
vendored
@@ -1,89 +0,0 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: [self-hosted, GPU-L20-4Card]
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
# Because the system version is lower than 2.23, the checkout cannot be used.
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge pr/${{ github.event.pull_request.number }}
|
||||
git log -n 3 --oneline
|
||||
else
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [ "${last_char}" = "1" ]; then
|
||||
gpu_id=2
|
||||
DEVICES="2,3"
|
||||
else
|
||||
gpu_id=0
|
||||
DEVICES="0,1"
|
||||
fi
|
||||
FD_API_PORT=$((9180 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((9170 + gpu_id * 100))
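# Worked example: on a runner whose name ends in "1" the branch above sets gpu_id=2, giving
#   FD_API_PORT          = 9180 + 2*100 = 9380
#   FD_ENGINE_QUEUE_PORT = 9150 + 2*100 = 9350
#   FD_METRICS_PORT      = 9170 + 2*100 = 9370
# otherwise gpu_id=0 and the ports stay at 9180/9150/9170.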
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-v "/ssd4/GithubActions/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "/ssd4/GithubActions/ModelData:/ModelData:ro" \
|
||||
-v "/ssd4/GithubActions/CacheDir:/root/.cache" \
|
||||
-v "/ssd4/GithubActions/ConfigDir:/root/.config" \
|
||||
-e "MODEL_PATH=/ModelData" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci.sh
|
||||
"
|
||||
27
.github/workflows/ci_gcu.yml
vendored
@@ -13,7 +13,8 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
CI_GCU:
|
||||
runs-on: [self-hosted, GCU-S60-8Card]
|
||||
runs-on:
|
||||
group: GCU
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
@@ -28,7 +29,9 @@ jobs:
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
docker run --rm --net=host -v $(pwd):/workspace \
|
||||
-v ${{ github.workspace }}/../../..:${{ github.workspace }}/../../.. \
|
||||
-w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
@@ -39,6 +42,7 @@ jobs:
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
source ${{ github.workspace }}/../../../proxy
|
||||
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
@@ -49,6 +53,9 @@ jobs:
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
echo "Copy models..."
|
||||
sudo mkdir -p ci_models && sudo cp -r /work/deps/ERNIE-4.5-21B-A3B-Paddle ci_models
|
||||
echo "Copy models done."
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
@@ -70,19 +77,21 @@ jobs:
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
echo "Install drivers..."
|
||||
cd /work/deps
|
||||
bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y
|
||||
sudo bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y
|
||||
cd -
|
||||
docker run --rm --network=host --ipc=host -it --privileged \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "/home:/home" \
|
||||
-v "/work:/work" \
|
||||
-e "MODEL_PATH=/work/models" \
|
||||
echo "Create docker..."
|
||||
docker run --rm --network=host --ipc=host --privileged \
|
||||
-v $(pwd):/workspace \
|
||||
-v /home:/home \
|
||||
-v /work:/work \
|
||||
-w /workspace \
|
||||
-e "MODEL_PATH=./ci_models" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
${docker_image} /bin/bash -c "
|
||||
${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci_gcu.sh
|
||||
|
||||
3
.github/workflows/ci_iluvatar.yml
vendored
@@ -11,7 +11,8 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
CI_ILUVATAR:
|
||||
runs-on: [self-hosted, IXUCA]
|
||||
runs-on:
|
||||
group: IXUCA
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
|
||||
5
.github/workflows/ci_xpu.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
@@ -77,6 +77,7 @@ jobs:
|
||||
-e "MODEL_PATH=/ssd3/model" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
|
||||
2
.github/workflows/gh-pages.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.x
|
||||
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang
|
||||
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang mkdocs-static-i18n
|
||||
- name: Deploy to GitHub Pages
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
53
.github/workflows/pr_build_and_test.yml
vendored
@@ -19,9 +19,9 @@ jobs:
|
||||
needs: clone
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "90"
|
||||
COMPILE_ARCH: "89,90"
|
||||
WITH_NIGHTLY_BUILD: "OFF"
|
||||
FD_VERSION: "0.0.0"
|
||||
|
||||
@@ -39,16 +39,59 @@ jobs:
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
logprob_test:
|
||||
name: Run FastDeploy LogProb Tests
|
||||
needs: [build]
|
||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelCache"
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
pre_ce_test:
|
||||
name: Extracted partial CE model tasks to run in CI.
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_pre_ce_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
base_test:
|
||||
name: Run Base Tests
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_base_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
accuracy_test:
|
||||
name: Run Accuracy Tests
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_accuracy_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
321
.github/workflows/publish_job.yml
vendored
Normal file
@@ -0,0 +1,321 @@
|
||||
name: Publish Job
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
|
||||
push:
|
||||
# branches:
|
||||
# - develop
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
jobs:
|
||||
publish_pre_check:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.repository.fork == false &&
|
||||
(
|
||||
(github.event_name == 'schedule' && github.ref_name == 'develop') ||
|
||||
(github.event_name == 'push' && github.ref_type == 'tag') ||
|
||||
((github.event_name == 'workflow_dispatch') &&
|
||||
(github.ref_name == 'develop' || github.ref_type == 'tag'))
|
||||
)
|
||||
env:
|
||||
TAG_VERSION_MAPPINGS: ${{ vars.TAG_VERSION_MAPPINGS }}
|
||||
FD_VERSION_DEV: ${{ vars.FD_VERSION_DEV }}
|
||||
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
|
||||
outputs:
|
||||
compile_use_paddle_version: ${{ steps.set_output.outputs.compile_use_paddle_version }}
|
||||
compile_continue: ${{ steps.set_output.outputs.compile_continue }}
|
||||
fd_version: ${{ steps.set_output.outputs.fd_version }}
|
||||
with_nightly_build: ${{ steps.set_output.outputs.with_nightly_build }}
|
||||
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
steps:
|
||||
- name: Get tag version
|
||||
if: github.ref_type == 'tag'
|
||||
run: |
|
||||
TAG_NAME="${GITHUB_REF##*/}" # 提取 tag 名称,比如 v2.1.0
|
||||
TAG_VERSION="${TAG_NAME#v}" # 去掉前缀 v
|
||||
echo "FD_VERSION=$TAG_VERSION" >> $GITHUB_ENV
|
||||
|
||||
- name: Check FD version to Paddle version mapping
|
||||
if: github.ref_type == 'tag'
|
||||
env:
|
||||
TARGET_FD: ${{ env.FD_VERSION }}
|
||||
run: |
|
||||
FOUND_PADDLE=""
|
||||
# Iterate over the FD-to-Paddle version mappings
|
||||
for pair in $(echo $TAG_VERSION_MAPPINGS | tr ';' ' '); do
|
||||
fd=$(echo "$pair" | cut -d',' -f1)
|
||||
paddle=$(echo "$pair" | cut -d',' -f2)
|
||||
if [[ "$fd" == "$TARGET_FD" ]]; then
|
||||
FOUND_PADDLE="$paddle"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ -z "$FOUND_PADDLE" ]]; then
|
||||
echo "No Paddle version found for FD $TARGET_FD"
|
||||
else
|
||||
echo "FD $TARGET_FD maps to Paddle $FOUND_PADDLE"
|
||||
echo "PADDLE_VERSION=$FOUND_PADDLE" >> $GITHUB_ENV
|
||||
fi
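TAG_VERSION_MAPPINGS is a repository variable that is not shown in this diff; the loop above only assumes semicolon-separated fd_version,paddle_version pairs. A hypothetical example of the lookup:

    # hypothetical values -- the real mapping is maintained as a GitHub repository variable
    TAG_VERSION_MAPPINGS='2.1.0,3.1.0;2.2.0,3.2.0'
    TARGET_FD=2.2.0
    for pair in $(echo $TAG_VERSION_MAPPINGS | tr ';' ' '); do
        fd=$(echo "$pair" | cut -d',' -f1)
        paddle=$(echo "$pair" | cut -d',' -f2)
        [[ "$fd" == "$TARGET_FD" ]] && echo "FD $fd maps to Paddle $paddle"   # FD 2.2.0 maps to Paddle 3.2.0
    done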
|
||||
- name: Set Version
|
||||
id: set_output
|
||||
env:
|
||||
PADDLE_VERSION: ${{ env.PADDLE_VERSION }}
|
||||
FD_VERSION: ${{ env.FD_VERSION }}
|
||||
run: |
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
if [[ -z "$PADDLE_VERSION" ]]; then
|
||||
compile_continue=false
|
||||
else
|
||||
compile_use_paddle_version=$PADDLE_VERSION
|
||||
compile_continue=true
|
||||
fi
|
||||
fd_version=$FD_VERSION
|
||||
fi
|
||||
if [[ "${{ github.ref_name }}" == "develop" ]];then
|
||||
compile_continue=true
|
||||
compile_use_paddle_version=""
|
||||
fd_version=${FD_VERSION_DEV}
|
||||
with_nightly_build=ON
|
||||
fi
|
||||
# Todo
|
||||
# Use the mapping in COMPILE_USE_PADDLE_WHL_URL_MAPPINGS to decide whether to install a pinned Paddle version or install directly from a wheel URL
|
||||
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
|
||||
branch=$(echo "$pair" | cut -d',' -f1)
|
||||
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
|
||||
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
|
||||
FOUND_PADDLE_URL="$paddle_whl_url"
|
||||
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
|
||||
compile_continue=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
echo "compile_continue=${compile_continue}" >> $GITHUB_OUTPUT
|
||||
echo "compile_use_paddle_version=${compile_use_paddle_version}" >> $GITHUB_OUTPUT
|
||||
echo "fd_version=${fd_version}" >> $GITHUB_OUTPUT
|
||||
echo "with_nightly_build=${with_nightly_build:-OFF}" >> $GITHUB_OUTPUT
|
||||
|
||||
print_publish_pre_check_outputs:
|
||||
runs-on: ubuntu-latest
|
||||
needs: publish_pre_check
|
||||
steps:
|
||||
- name: Print outputs as JSON
|
||||
run: |
|
||||
echo '${{ toJSON(needs.publish_pre_check.outputs) }}'
|
||||
|
||||
clone:
|
||||
environment: CodeSync
|
||||
name: FD-Clone-Linux
|
||||
runs-on: ubuntu-latest
|
||||
needs: publish_pre_check
|
||||
if: ${{ needs.publish_pre_check.outputs.compile_continue == 'true' }}
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Python Setup
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
commit_id=${{ github.sha }}
|
||||
tag_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
|
||||
else
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
||||
|
||||
resultshow:
|
||||
name: Show Code Archive Output
|
||||
needs: clone
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Print wheel path
|
||||
run: |
|
||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
||||
|
||||
build_sm8090:
|
||||
name: BUILD_SM8090
|
||||
needs: [clone, publish_pre_check]
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "80,90"
|
||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
build_sm8689:
|
||||
name: BUILD_SM8689
|
||||
needs: [clone, publish_pre_check]
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "86,89"
|
||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
paddle_pypi_upload_sm8090:
|
||||
environment: PaddleSourceUpload
|
||||
name: PADDLE_PYPI_UPLOAD_8090
|
||||
needs: build_sm8090
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
COMPILE_ARCH: "80,90"
|
||||
steps:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
if: github.ref_name == 'develop' || github.ref_type == 'tag'
|
||||
run: |
|
||||
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
|
||||
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
|
||||
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
|
||||
if [[ "${{ github.ref_name }}" == "develop" ]];then
|
||||
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
|
||||
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
|
||||
else
|
||||
echo "Not develop or tag, do nothing"
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
|
||||
paddle_pypi_upload_sm8689:
|
||||
environment: PaddleSourceUpload
|
||||
name: PADDLE_PYPI_UPLOAD_8689
|
||||
needs: build_sm8689
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
|
||||
COMPILE_ARCH: "86,89"
|
||||
steps:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
if: github.ref_name == 'develop' || github.ref_type == 'tag'
|
||||
run: |
|
||||
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
|
||||
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
|
||||
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
|
||||
if [[ "${{ github.ref_name }}" == "develop" ]];then
|
||||
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
|
||||
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
|
||||
else
|
||||
echo "Not develop or tag, do nothing"
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
|
||||
unittest_coverage:
|
||||
name: Run FastDeploy Unit Tests and Coverage
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
logprob_test:
|
||||
name: Run FastDeploy LogProb Tests
|
||||
needs: [build_sm8090]
|
||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
pre_ce_test:
|
||||
name: Extracted partial CE model tasks to run in CI.
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_pre_ce_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
base_test:
|
||||
name: Run Base Tests
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_base_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
accuracy_test:
|
||||
name: Run Accuracy Tests
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_accuracy_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
14
.gitignore
vendored
@@ -121,7 +121,7 @@ dmypy.json
|
||||
FETCH_HEAD
|
||||
|
||||
#log
|
||||
log*/
|
||||
log/
|
||||
|
||||
checkpoints/
|
||||
checkpoints_origin/
|
||||
@@ -156,6 +156,12 @@ nohup.out
|
||||
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cutlass
|
||||
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
|
||||
|
||||
#marlin_kernel
|
||||
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu
|
||||
|
||||
#machete_kernel
|
||||
custom_ops/gpu_ops/machete/generated
|
||||
|
||||
# buff
|
||||
custom_ops/tmp*
|
||||
|
||||
@@ -164,3 +170,9 @@ build
|
||||
.ccls-cache
|
||||
|
||||
third_party
|
||||
|
||||
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
|
||||
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
|
||||
|
||||
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu
|
||||
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h
|
||||
|
||||
9
.gitmodules
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
[submodule "custom_ops/third_party/DeepGEMM"]
|
||||
path = custom_ops/third_party/DeepGEMM
|
||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||
[submodule "custom_ops/third_party/cutlass"]
|
||||
path = custom_ops/third_party/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
[submodule "custom_ops/third_party/nlohmann_json"]
|
||||
path = custom_ops/third_party/nlohmann_json
|
||||
url = https://github.com/nlohmann/json.git
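Since the CI jobs above check the repository out with submodules: 'recursive', a local build needs the same submodules; the standard git commands for a fresh clone are:

    git clone https://github.com/PaddlePaddle/FastDeploy.git
    cd FastDeploy
    git submodule update --init --recursive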
|
||||
25
README.md
@@ -1,3 +1,4 @@
|
||||
English | [简体中文](README_CN.md)
|
||||
<p align="center">
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
|
||||
</p>
|
||||
@@ -22,11 +23,12 @@
|
||||
</p>
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
|
||||
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
|
||||
|
||||
## News
|
||||
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It now offers compatibility with models in the HuggingFace ecosystem, has further optimized performance, and newly adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!
|
||||
|
||||
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
|
||||
**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.
|
||||
|
||||
**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
|
||||
|
||||
@@ -50,14 +52,16 @@
|
||||
|
||||
## Installation
|
||||
|
||||
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. For detailed installation instructions:
|
||||
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, **Hygon DCUs** and other hardware. For detailed installation instructions:
|
||||
|
||||
- [NVIDIA GPU](./docs/get_started/installation/nvidia_gpu.md)
|
||||
- [Kunlunxin XPU](./docs/get_started/installation/kunlunxin_xpu.md)
|
||||
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
|
||||
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
|
||||
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
|
||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md.md)
|
||||
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates!
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
|
||||
|
||||
## Get Started
|
||||
|
||||
@@ -67,19 +71,12 @@ Learn how to use FastDeploy through our documentation:
|
||||
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
|
||||
- [Offline Inference Development](./docs/offline_inference.md)
|
||||
- [Online Service Deployment](./docs/online_serving/README.md)
|
||||
- [Full Supported Models List](./docs/supported_models.md)
|
||||
- [Best Practices](./docs/best_practices/README.md)
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|
||||
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|
||||
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅(WINT4)| WIP |128K |
|
||||
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|✅(WINT4)| WIP | 128K |
|
||||
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|
||||
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|
||||
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
|
||||
Learn how to download models, enable using the torch format, and more:
|
||||
- [Full Supported Models List](./docs/supported_models.md)
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
|
||||
89
README_CN.md
Normal file
@@ -0,0 +1,89 @@
|
||||
[English](README.md) | 简体中文
|
||||
<p align="center">
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href=""><img src="https://img.shields.io/badge/python-3.10-aff.svg"></a>
|
||||
<a href=""><img src="https://img.shields.io/badge/os-linux-pink.svg"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/graphs/contributors"><img src="https://img.shields.io/github/contributors/PaddlePaddle/FastDeploy?color=9ea"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
|
||||
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/"><b> 安装指导 </b></a>
|
||||
|
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/quick_start"><b> 快速入门 </b></a>
|
||||
|
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/zh/supported_models/"><b> 支持模型列表 </b></a>
|
||||
|
||||
</p>
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
# FastDeploy :基于飞桨的大语言模型与视觉语言模型推理部署工具包
|
||||
|
||||
## 最新活动
|
||||
**[2025-09] 🔥 FastDeploy v2.2 全新发布**: HuggingFace生态模型兼容,性能进一步优化,更新增对[baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)支持!
|
||||
|
||||
**[2025-08] FastDeploy v2.1 发布**:全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。
|
||||
|
||||
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
|
||||
|
||||
## 关于
|
||||
|
||||
**FastDeploy** 是基于飞桨(PaddlePaddle)的大语言模型(LLM)与视觉语言模型(VLM)推理部署工具包,提供**开箱即用的生产级部署方案**,核心技术特性包括:
|
||||
|
||||
- 🚀 **负载均衡式PD分解**:工业级解决方案,支持上下文缓存与动态实例角色切换,在保障SLO达标和吞吐量的同时优化资源利用率
|
||||
- 🔄 **统一KV缓存传输**:轻量级高性能传输库,支持智能NVLink/RDMA选择
|
||||
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
|
||||
- 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
|
||||
- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
|
||||
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
|
||||
|
||||
## 要求
|
||||
|
||||
- 操作系统: Linux
|
||||
- Python: 3.10 ~ 3.12
|
||||
|
||||
## 安装
|
||||
|
||||
FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU**、**天数(Iluvatar)GPU**、**燧原(Enflame)GCU**、**海光(Hygon)DCU** 以及其他硬件上进行推理部署。详细安装说明如下:
|
||||
|
||||
- [英伟达 GPU](./docs/zh/get_started/installation/nvidia_gpu.md)
|
||||
- [昆仑芯 XPU](./docs/zh/get_started/installation/kunlunxin_xpu.md)
|
||||
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
|
||||
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
|
||||
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
|
||||
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md.md)
|
||||
|
||||
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!
|
||||
|
||||
## 入门指南
|
||||
|
||||
通过我们的文档了解如何使用 FastDeploy:
|
||||
- [10分钟快速部署](./docs/zh/get_started/quick_start.md)
|
||||
- [ERNIE-4.5 部署](./docs/zh/get_started/ernie-4.5.md)
|
||||
- [ERNIE-4.5-VL 部署](./docs/zh/get_started/ernie-4.5-vl.md)
|
||||
- [离线推理](./docs/zh/offline_inference.md)
|
||||
- [在线服务](./docs/zh/online_serving/README.md)
|
||||
- [最佳实践](./docs/zh/best_practices/README.md)
|
||||
|
||||
## 支持模型列表
|
||||
|
||||
通过我们的文档了解如何下载模型,如何支持torch格式等:
|
||||
- [模型支持列表](./docs/zh/supported_models.md)
|
||||
|
||||
## 进阶用法
|
||||
|
||||
- [量化](./docs/zh/quantization/README.md)
|
||||
- [分离式部署](./docs/zh/features/disaggregated.md)
|
||||
- [投机解码](./docs/zh/features/speculative_decoding.md)
|
||||
- [前缀缓存](./docs/zh/features/prefix_caching.md)
|
||||
- [分块预填充](./docs/zh/features/chunked_prefill.md)
|
||||
|
||||
## 致谢
|
||||
|
||||
FastDeploy 依据 [Apache-2.0 开源许可证](./LICENSE). 进行授权。在开发过程中,我们参考并借鉴了 [vLLM](https://github.com/vllm-project/vllm) 的部分代码,以保持接口兼容性,在此表示衷心感谢。
|
||||
6
benchmarks/yaml/eb45-8k-fp8-tp1-dp8_ep.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
num_gpu_blocks_override: 1024
|
||||
max_model_len: 8192
|
||||
max_num_seqs: 64
|
||||
data_parallel_size: 8
|
||||
tensor_parallel_size: 1
|
||||
enable_expert_parallel: True
|
||||
8
benchmarks/yaml/request_yaml/x1.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
top_p: 0.95
|
||||
temperature: 0.6
|
||||
metadata:
|
||||
min_tokens: 1
|
||||
max_tokens: 65535
|
||||
repetition_penalty: 1.0
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
||||
10
benchmarks/yaml/x1-64k-w4a8c8-tp4.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
reasoning-parser: ernie_x1
|
||||
tool_call_parser: ernie_x1
|
||||
tensor_parallel_size: 4
|
||||
max_model_len: 65536
|
||||
max_num_seqs: 128
|
||||
enable_prefix_caching: True
|
||||
enable_chunked_prefill: True
|
||||
gpu_memory_utilization: 0.85
|
||||
use_cudagraph: True
|
||||
enable_custom_all_reduce: True
|
||||
30
build.sh
@@ -34,7 +34,6 @@ EGG_DIR="fastdeploy.egg-info"
|
||||
|
||||
# custom_ops directory config
|
||||
OPS_SRC_DIR="custom_ops"
|
||||
OPS_TMP_DIR_BASE="tmp_base"
|
||||
OPS_TMP_DIR="tmp"
|
||||
|
||||
# command line log config
|
||||
@@ -71,25 +70,20 @@ function copy_ops(){
|
||||
PY_VERSION="py${PY_MAIN_VERSION}.${PY_SUB_VERSION}"
|
||||
SYSTEM_VERSION=`${python} -c "import platform; print(platform.system().lower())"`
|
||||
PROCESSOR_VERSION=`${python} -c "import platform; print(platform.processor())"`
|
||||
WHEEL_BASE_NAME="fastdeploy_base_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
|
||||
WHEEL_NAME="fastdeploy_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
|
||||
WHEEL_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
|
||||
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
|
||||
if [ "$is_rocm" = "True" ]; then
|
||||
DEVICE_TYPE="rocm"
|
||||
mkdir -p ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
|
||||
echo -e "BASE and ROCM ops have been copy to fastdeploy"
|
||||
echo -e "ROCM ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
mkdir -p ../fastdeploy/model_executor/ops/base
|
||||
is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"`
|
||||
if [ "$is_cuda" = "True" ]; then
|
||||
DEVICE_TYPE="gpu"
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
|
||||
echo -e "BASE and CUDA ops have been copy to fastdeploy"
|
||||
echo -e "CUDA ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
@@ -112,9 +106,8 @@ function copy_ops(){
|
||||
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
|
||||
if [ "$if_corex" = "True" ]; then
|
||||
DEVICE_TYPE="iluvatar-gpu"
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
|
||||
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
|
||||
echo -e "Iluvatar ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
@@ -126,20 +119,26 @@ function copy_ops(){
|
||||
return
|
||||
fi
|
||||
|
||||
is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
|
||||
if [ "$is_maca" = "True" ]; then
|
||||
DEVICE_TYPE="metax_gpu"
|
||||
mkdir -p ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
|
||||
echo -e "MACA ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
DEVICE_TYPE="cpu"
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cd ../../../../
|
||||
cp -r ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu
|
||||
echo -e "BASE and CPU ops have been copy to fastdeploy"
|
||||
echo -e "CPU ops have been copy to fastdeploy"
|
||||
return
|
||||
}
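copy_ops() decides where the compiled operators are copied by probing how the installed Paddle wheel was built; the probes are the same one-liners used above and can be run standalone to see which branch a given environment would take:

    python -c "import paddle; print(paddle.is_compiled_with_rocm())"                            # ROCm build?
    python -c "import paddle; print(paddle.is_compiled_with_cuda())"                            # CUDA build?
    python -c "import paddle; print(paddle.is_compiled_with_xpu())"                             # XPU build?
    python -c "import paddle; print(paddle.is_compiled_with_custom_device('iluvatar_gpu'))"     # Iluvatar build?
    python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))" # MetaX build?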
|
||||
|
||||
function build_and_install_ops() {
cd $OPS_SRC_DIR
export no_proxy=bcebos.com,paddlepaddle.org.cn,${no_proxy}
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_base_ops..."
${python} setup_ops_base.py install --install-lib ${OPS_TMP_DIR_BASE}
find ${OPS_TMP_DIR_BASE} -type f -name "*.o" -exec rm -f {} \;
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..."
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
@@ -213,7 +212,6 @@ function cleanup() {
fi

rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR_BASE
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR
}

@@ -84,7 +84,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
seq_length,
bsz);
return {x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k};
@@ -97,7 +96,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -106,7 +105,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
@@ -115,7 +113,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})

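Editor's note on the {bsz + 1} shapes in the hunk above: cu_seqlens_q and cu_seqlens_k are cumulative sequence-length arrays, one prefix-sum entry per batch element plus a leading zero, which is why they are inferred with a bsz + 1 extent while the packed token tensor keeps a dynamic -1 leading dimension. A minimal standalone C++ sketch of that convention follows; the helper name and the example values are illustrative assumptions, not code from this repository.

#include <cstdint>
#include <vector>

// Build a cumulative-sequence-length array of size bsz + 1 from per-sequence
// token counts: cu_seqlens[i] is the packed offset of sequence i's first token,
// and cu_seqlens[bsz] is the total token count.
std::vector<int32_t> build_cu_seqlens(const std::vector<int32_t> &seq_lens) {
  std::vector<int32_t> cu_seqlens(seq_lens.size() + 1, 0);
  for (size_t i = 0; i < seq_lens.size(); ++i) {
    cu_seqlens[i + 1] = cu_seqlens[i] + seq_lens[i];
  }
  return cu_seqlens;
}

// Example: seq_lens = {3, 1, 4} gives cu_seqlens = {0, 3, 4, 8}.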
@@ -1,4 +1,4 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,10 +19,11 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

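The PD_BUILD_STATIC_OP wrapper shown above prefixes the registered op name with static_op_ via preprocessor token pasting. A tiny self-contained sketch of that mechanism follows; PD_BUILD_OP here is a stub standing in for the real Paddle registration macro.

#include <cstdio>

// Stub standing in for the real PD_BUILD_OP registration macro.
#define PD_BUILD_OP(name) \
  static const char *registered_##name = #name;

// Same shape as the hunk above: forward to PD_BUILD_OP with a static_op_ prefix.
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)

PD_BUILD_STATIC_OP(get_padding_offset_cpu)  // expands to registered_static_op_get_padding_offset_cpu

int main() {
  std::printf("%s\n", registered_static_op_get_padding_offset_cpu);
  return 0;
}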
template <typename T>
void RebuildPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cum_offsets_data,
const int *cu_seqlens_q_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
@@ -40,11 +41,12 @@ void RebuildPaddingCPUImpl(T *output_data,
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
continue;
}

if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
const int ori_token_idx =
bi * max_input_length - cum_offsets_data[bi] + seq_id;

const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx;

output_data[i] = input_data[src_offset];
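The hunk above replaces the padded-layout index bi * max_input_length - cum_offsets_data[bi] + seq_id with cu_seqlens_q_data[bi] + seq_id. A small standalone sketch of the two formulas is given below; the concrete lengths and offsets are hypothetical values chosen for illustration, not taken from this repository.

#include <cassert>

// Both expressions locate a prefill token of batch element bi in the packed
// (padding-free) token buffer. The old form starts from the padded position
// and subtracts the padding accumulated before this sequence; the new form
// starts directly from the sequence's offset in the packed layout.
int packed_index_old(int bi, int max_input_length, const int *cum_offsets, int seq_id) {
  return bi * max_input_length - cum_offsets[bi] + seq_id;
}

int packed_index_new(int bi, const int *cu_seqlens_q, int seq_id) {
  return cu_seqlens_q[bi] + seq_id;
}

int main() {
  // Two sequences of lengths 3 and 2, padded to max_input_length = 4.
  const int max_input_length = 4;
  const int cum_offsets[] = {0, 1};      // padding accumulated before each sequence
  const int cu_seqlens_q[] = {0, 3, 5};  // packed start offset of each sequence
  // Last token of sequence 1 (seq_id = 1): both formulas give packed index 4.
  assert(packed_index_old(1, max_input_length, cum_offsets, 1) ==
         packed_index_new(1, cu_seqlens_q, 1));
  return 0;
}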
@@ -54,7 +56,7 @@ void RebuildPaddingCPUImpl(T *output_data,
template <typename T>
void RebuildAppendPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cum_offsets_data,
const int *cu_seqlens_q_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
@@ -69,30 +71,32 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
int bi = ori_token_id / max_input_length;
if (seq_len_this_time_data[bi] == 0 ||
(seq_lens_decoder_data[bi] == 0 &&
seq_lens_encoder_data[bi] == 0)) {
continue;
}
seq_lens_encoder_data[bi] == 0)) {
continue;
}
int seq_id = 0;

if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id;
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
int bias_idx = i % dim_embed;
int src_offset = input_token_id * dim_embed + bias_idx;

output_data[i] = input_data[src_offset];
}
}
|
||||
std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
const paddle::Tensor &tmp_out,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &seq_len_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
|
||||
auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true);
|
||||
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_len_this_time_cpu =
|
||||
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_lens_decoder_cpu =
|
||||
@@ -107,7 +111,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
|
||||
int token_num = tmp_out_cpu.shape()[0];
|
||||
int dim_embed = tmp_out_cpu.shape()[1];
|
||||
int bsz = cum_offsets_cpu.shape()[0];
|
||||
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
|
||||
|
||||
paddle::Tensor out;
|
||||
if (output_padding_offset_cpu) {
|
||||
@@ -128,7 +132,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
|
||||
}
|
||||
|
||||
const int *cum_offsets_data = cum_offsets_cpu.data<int>();
|
||||
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
|
||||
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
|
||||
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
||||
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
|
||||
@@ -141,7 +145,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
case paddle::DataType::FLOAT32:
|
||||
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
|
||||
tmp_out_cpu.data<float>(),
|
||||
cum_offsets_data,
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -154,7 +158,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
RebuildAppendPaddingCPUImpl<paddle::float16>(
|
||||
out.data<paddle::float16>(),
|
||||
tmp_out_cpu.data<paddle::float16>(),
|
||||
cum_offsets_data,
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -167,7 +171,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
|
||||
out.data<paddle::bfloat16>(),
|
||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||
cum_offsets_data,
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -186,7 +190,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
case paddle::DataType::FLOAT32:
|
||||
RebuildPaddingCPUImpl<float>(out.data<float>(),
|
||||
tmp_out_cpu.data<float>(),
|
||||
cum_offsets_data,
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -198,7 +202,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
RebuildPaddingCPUImpl<paddle::float16>(
|
||||
out.data<paddle::float16>(),
|
||||
tmp_out_cpu.data<paddle::float16>(),
|
||||
cum_offsets_data,
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -207,11 +211,10 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::BFLOAT16:
|
||||
|
||||
RebuildPaddingCPUImpl<paddle::bfloat16>(
|
||||
out.data<paddle::bfloat16>(),
|
||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||
cum_offsets_data,
|
||||
cu_seqlens_q_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -230,7 +233,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
|
||||
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||
const std::vector<int64_t> &tmp_out_shape,
|
||||
const std::vector<int64_t> &cum_offsets_shape,
|
||||
const std::vector<int64_t> &cu_seqlens_q_shape,
|
||||
const std::vector<int64_t> &seq_len_this_time_shape,
|
||||
const std::vector<int64_t> &seq_lens_decoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_encoder_shape,
|
||||
@@ -239,14 +242,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||
if (output_padding_offset_shape) {
|
||||
return {{-1, dim_embed}};
|
||||
} else {
|
||||
int64_t bsz = cum_offsets_shape[0];
|
||||
int64_t bsz = cu_seqlens_q_shape[0] - 1;
|
||||
return {{bsz, dim_embed}};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||
const paddle::DataType &tmp_out_dtype,
|
||||
const paddle::DataType &cum_offsets_dtype,
|
||||
const paddle::DataType &cu_seqlens_q_dtype,
|
||||
const paddle::DataType &seq_len_this_time_dtype,
|
||||
const paddle::DataType &seq_lens_decoder_dtype,
|
||||
const paddle::DataType &seq_lens_encoder_dtype,
|
||||
@@ -256,7 +259,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||
|
||||
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
|
||||
.Inputs({"tmp_out",
|
||||
"cum_offsets",
|
||||
"cu_seqlens_q",
|
||||
"seq_len_this_time",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_encoder",
|
||||
|
||||
@@ -38,7 +38,7 @@ class type2value<phi::dtype::float16> {
|
||||
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
void AppendAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
@@ -60,6 +60,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
paddle::Tensor& fmha_out,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
@@ -72,7 +73,11 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
@@ -118,27 +123,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
} else {
|
||||
qkv_out = qkv;
|
||||
}
|
||||
paddle::Tensor fmha_out;
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::INT8,
|
||||
qkv.place());
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
qkv.place());
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
|
||||
}
|
||||
} else {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
D,
|
||||
qkv.place());
|
||||
}
|
||||
|
||||
auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
|
||||
const paddle::Tensor& lambda_batch_ids,
|
||||
@@ -223,7 +207,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
main_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
};
|
||||
|
||||
if (qkv_out_scales) {
|
||||
@@ -339,7 +326,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
@@ -363,7 +353,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -392,8 +385,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
cudaStreamWaitEvent(main_stream, decoder_event);
|
||||
}
|
||||
}
|
||||
|
||||
return {fmha_out, qkv_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> AppendAttention(
|
||||
@@ -429,7 +420,11 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -464,8 +459,60 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
// template dtype generation
phi::DataType dtype_id;
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
dtype_id = phi::DataType::BFLOAT16;
break;
} else if (compute_dtype == "fp16") {
dtype_id = phi::DataType::FLOAT16;
break;
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}

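The dtype_id switch above feeds a tag-dispatch pattern: a generic lambda is handed a dummy value of the half-precision tag type, and type2value<decltype(temp_args)>::value maps that tag back to a compile-time data-type constant so the matching AppendAttentionKernel instantiation is selected. The simplified, self-contained sketch below illustrates only the mapping; the enum, tag types, and trait here are stand-ins, not the actual phi or type2value definitions.

#include <cstdio>

// Stand-in runtime dtype enum and tag types (the real ones live in phi/paddle).
enum class DataType { FLOAT16, BFLOAT16 };
struct float16_tag {};
struct bfloat16_tag {};

// Trait mapping a tag type to its compile-time DataType value, mirroring the
// role of type2value<T>::value in the hunk above.
template <typename T> struct type2value;
template <> struct type2value<float16_tag>  { static constexpr DataType value = DataType::FLOAT16; };
template <> struct type2value<bfloat16_tag> { static constexpr DataType value = DataType::BFLOAT16; };

template <DataType D>
void AttentionKernelSketch() {
  std::printf("instantiated for dtype %d\n", static_cast<int>(D));
}

int main() {
  DataType dtype_id = DataType::BFLOAT16;  // chosen at runtime, e.g. from the qkv dtype
  auto dispatch_by_template = [&](auto temp_args) {
    AttentionKernelSketch<type2value<decltype(temp_args)>::value>();
  };
  switch (dtype_id) {
    case DataType::FLOAT16:  dispatch_by_template(float16_tag{});  break;
    case DataType::BFLOAT16: dispatch_by_template(bfloat16_tag{}); break;
  }
  return 0;
}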
// fmha_out generation, rewrite from AppendAttentionKernel
|
||||
paddle::Tensor fmha_out;
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::INT8,
|
||||
qkv.place());
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
qkv.place());
|
||||
} else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
|
||||
}
|
||||
} else {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
dtype_id,
|
||||
qkv.place());
|
||||
}
|
||||
|
||||
if (mask_offset) {
|
||||
meta_data.mask_offset = mask_offset.get().data<int>();
|
||||
}
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> void {
|
||||
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
meta_data,
|
||||
qkv,
|
||||
key_cache,
|
||||
@@ -487,6 +534,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
decoder_num_blocks,
|
||||
set_max_lengths,
|
||||
max_len_kv,
|
||||
fmha_out,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
qkv_bias,
|
||||
@@ -499,7 +547,11 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
mask_offset,
|
||||
kv_signal_data,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
@@ -514,20 +566,183 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
};
|
||||
|
||||
|
||||
phi::dtype::float16 fp16_dtype;
|
||||
phi::dtype::bfloat16 bp16_dtype;
|
||||
switch (dtype_id){
|
||||
case phi::DataType::FLOAT16: {
|
||||
dispatch_by_template(fp16_dtype);
|
||||
return {fmha_out};
|
||||
}
|
||||
case phi::DataType::BFLOAT16: {
|
||||
dispatch_by_template(bp16_dtype);
|
||||
return {fmha_out};
|
||||
}
|
||||
default:
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and bfloat16 are supported. ");
|
||||
break;
|
||||
}
|
||||
|
||||
return {paddle::Tensor{}};
|
||||
}
|
||||
|
||||
void AppendAttentionWithOutput(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
paddle::Tensor& fmha_out,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
AppendAttnMetaData meta_data;
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
const auto& key_cache_dims = key_cache.dims();
|
||||
meta_data.token_nums = qkv_dims[0];
|
||||
meta_data.kv_num_heads = key_cache_dims[1];
|
||||
meta_data.head_dims = key_cache_dims[3];
|
||||
// TODO: trick method support c4, add attr head_dims in the future
|
||||
if (cache_quant_type_str == "cache_int4_zp") {
|
||||
meta_data.head_dims *= 2;
|
||||
}
|
||||
const int total_num_head =
|
||||
qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
|
||||
meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
if (mask_offset) {
|
||||
meta_data.mask_offset = mask_offset.get().data<int>();
|
||||
}
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> void {
|
||||
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
meta_data,
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
set_max_lengths,
|
||||
max_len_kv,
|
||||
fmha_out,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
qkv_bias,
|
||||
qkv_out_scales,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
mask_offset,
|
||||
kv_signal_data,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
encoder_block_shape_q,
|
||||
decoder_block_shape_q,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
};
|
||||
|
||||
phi::dtype::float16 fp16_dtype;
|
||||
phi::dtype::bfloat16 bp16_dtype;
|
||||
|
||||
switch (qkv.dtype()) {
|
||||
case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype);
|
||||
case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype);
|
||||
case paddle::DataType::FLOAT16: {
|
||||
dispatch_by_template(fp16_dtype);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
dispatch_by_template(bp16_dtype);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::INT32: {
|
||||
if (compute_dtype == "bf16") {
|
||||
return dispatch_by_template(bp16_dtype);
|
||||
dispatch_by_template(bp16_dtype);
|
||||
break;
|
||||
} else if (compute_dtype == "fp16") {
|
||||
return dispatch_by_template(fp16_dtype);
|
||||
dispatch_by_template(fp16_dtype);
|
||||
break;
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
break;
|
||||
@@ -540,9 +755,9 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
break;
|
||||
}
|
||||
}
|
||||
return {paddle::Tensor{}};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
@@ -576,7 +791,11 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -600,7 +819,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
}
|
||||
const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
|
||||
const int num_heads = total_num_head - 2 * kv_num_heads;
|
||||
return {{token_num, num_heads * head_dim}, qkv_shape};
|
||||
return {{token_num, num_heads * head_dim}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
@@ -636,7 +855,11 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||
const paddle::optional<paddle::DataType>& mask_offset_dtype,
|
||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -655,32 +878,148 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
if (compute_dtype == "bf16") {
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
|
||||
return {paddle::DataType::INT8};
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
|
||||
return {paddle::DataType::FLOAT8_E4M3FN};
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
|
||||
}
|
||||
} else {
|
||||
return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
|
||||
return {paddle::DataType::BFLOAT16};
|
||||
}
|
||||
} else if (compute_dtype == "fp16") {
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
|
||||
return {paddle::DataType::INT8};
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
|
||||
return {paddle::DataType::FLOAT8_E4M3FN};
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
|
||||
}
|
||||
} else {
|
||||
return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
|
||||
return {paddle::DataType::FLOAT16};
|
||||
}
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
||||
const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
const std::vector<int64_t>& value_cache_shape,
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& batch_id_per_token_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& encoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& kv_batch_ids_shape,
|
||||
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& kv_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& set_max_lengths_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const std::vector<int64_t>& fmha_out_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
return {fmha_out_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
||||
const paddle::DataType& qkv_dtype,
|
||||
const paddle::DataType& key_cache_dtype,
|
||||
const paddle::DataType& value_cache_dtype,
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& batch_id_per_token_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& encoder_num_blocks_dtype,
|
||||
const paddle::DataType& kv_batch_ids_dtype,
|
||||
const paddle::DataType& kv_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& kv_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_batch_ids_dtype,
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& set_max_lengths_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::DataType& fmha_out_dtype,
|
||||
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||
const paddle::optional<paddle::DataType>& mask_offset_dtype,
|
||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
return {fmha_out_dtype};
|
||||
}
|
||||
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(append_attention)
|
||||
.Inputs({"qkv",
|
||||
"key_cache",
|
||||
@@ -714,11 +1053,15 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
paddle::Optional("cache_v_zp"),
|
||||
paddle::Optional("out_linear_shifts"),
|
||||
paddle::Optional("out_linear_smooths"),
|
||||
paddle::Optional("kv_signal_data")})
|
||||
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||
paddle::Optional("mask_offset"),
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
.Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
|
||||
.SetInplaceMap({{"key_cache", "key_cache_out"},
|
||||
{"value_cache", "value_cache_out"}})
|
||||
.Attrs({"compute_type: std::string",
|
||||
.Attrs({"rms_norm_eps: float",
|
||||
"compute_type: std::string",
|
||||
"cache_quant_type: std::string",
|
||||
"use_neox_rotary_style: bool",
|
||||
"rope_3d: bool",
|
||||
@@ -732,7 +1075,71 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
"encoder_max_partition_size: int",
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool"})
|
||||
"speculate_decoder: bool",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(AppendAttention))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
|
||||
|
||||
PD_BUILD_STATIC_OP(append_attention_with_output)
|
||||
.Inputs({"qkv",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"set_max_lengths",
|
||||
"max_len_kv",
|
||||
"fmha_out",
|
||||
paddle::Optional("rotary_embs"),
|
||||
paddle::Optional("attn_mask"),
|
||||
paddle::Optional("qkv_bias"),
|
||||
paddle::Optional("qkv_out_scales"),
|
||||
paddle::Optional("cache_k_quant_scales"),
|
||||
paddle::Optional("cache_v_quant_scales"),
|
||||
paddle::Optional("cache_k_dequant_scales"),
|
||||
paddle::Optional("cache_v_dequant_scales"),
|
||||
paddle::Optional("cache_k_zp"),
|
||||
paddle::Optional("cache_v_zp"),
|
||||
paddle::Optional("out_linear_shifts"),
|
||||
paddle::Optional("out_linear_smooths"),
|
||||
paddle::Optional("mask_offset"),
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
.Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||
.SetInplaceMap({{"fmha_out", "fmha_out_out"},
|
||||
{"key_cache", "key_cache_out"},
|
||||
{"value_cache", "value_cache_out"}})
|
||||
.Attrs({"rms_norm_eps: float",
|
||||
"compute_type: std::string",
|
||||
"cache_quant_type: std::string",
|
||||
"use_neox_rotary_style: bool",
|
||||
"rope_3d: bool",
|
||||
"max_input_length: int",
|
||||
"quant_max_bound: float",
|
||||
"quant_min_bound: float",
|
||||
"out_linear_in_scale: float",
|
||||
"encoder_block_shape_q: int",
|
||||
"decoder_block_shape_q: int",
|
||||
"max_partition_size: int",
|
||||
"encoder_max_partition_size: int",
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));
|
||||
|
||||
@@ -43,6 +43,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -141,6 +142,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -179,7 +181,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
chunk_start)))
|
||||
: chunk_len) /
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
(num_frags_z * 16);
|
||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
||||
@@ -245,12 +247,16 @@ __global__ void multi_query_append_attention_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
s_frag);
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -406,6 +412,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -419,7 +427,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const uint32_t attn_mask_len = -1) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
|
||||
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
|
||||
@@ -502,7 +511,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -540,10 +549,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t mask_check_iteration =
|
||||
(CAUSAL ? (min(chunk_len,
|
||||
sub_if_greater_or_zero(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
kv_len - q_len,
|
||||
chunk_start)))
|
||||
: chunk_len) /
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
(NUM_WARP_KV * num_frags_z * 16);
|
||||
|
||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -611,12 +619,15 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base + wid * num_frags_z * 16,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
s_frag);
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -882,6 +893,7 @@ void MultiQueryAppendAttention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -939,6 +951,7 @@ void MultiQueryAppendAttention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1061,12 +1074,18 @@ void MultiQueryAppendAttention(
|
||||
if (!is_decoder) {
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
uint32_t attn_mask_len;
|
||||
if (attn_mask) {
|
||||
attn_mask_len = attn_mask.get().shape()[1];
|
||||
} else {
|
||||
attn_mask_len = -1;
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
|
||||
if (num_chunks <= 1) {
|
||||
if (num_chunks <= 0) {
|
||||
auto nosplit_kv_kernel =
|
||||
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
|
||||
false,
|
||||
@@ -1104,6 +1123,9 @@ void MultiQueryAppendAttention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1116,7 +1138,8 @@ void MultiQueryAppendAttention(
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1161,8 +1184,8 @@ void MultiQueryAppendAttention(
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1172,6 +1195,9 @@ void MultiQueryAppendAttention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1184,7 +1210,8 @@ void MultiQueryAppendAttention(
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
@@ -1208,8 +1235,8 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1226,14 +1253,14 @@ void MultiQueryAppendAttention(
|
||||
constexpr int blockx = HEAD_DIM / vec_size;
|
||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||
num_heads);
|
||||
num_heads);
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
@@ -1244,8 +1271,8 @@ void MultiQueryAppendAttention(
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
|
||||
@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -172,6 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -248,7 +250,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
chunk_start)))
|
||||
: chunk_len) /
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
(num_frags_z * 16);
|
||||
|
||||
uint32_t k_smem_offset_r =
|
||||
@@ -333,12 +335,15 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
s_frag);
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
}
|
||||
|
||||
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
||||
@@ -505,6 +510,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -518,7 +525,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const uint32_t attn_mask_len = -1) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
|
||||
@@ -627,7 +635,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -703,10 +711,9 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t mask_check_iteration =
|
||||
(CAUSAL ? (min(chunk_len,
|
||||
sub_if_greater_or_zero(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
kv_len - q_len,
|
||||
chunk_start)))
|
||||
: chunk_len) /
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
(NUM_WARP_KV * num_frags_z * 16);
|
||||
|
||||
uint32_t k_smem_offset_r =
|
||||
@@ -788,12 +795,15 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
kv_idx_base + wid * num_frags_z * 16,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
s_frag);
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
}
|
||||
|
||||
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
||||
@@ -1088,6 +1098,7 @@ void MultiQueryAppendC4Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1151,6 +1162,7 @@ void MultiQueryAppendC4Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1285,10 +1297,18 @@ void MultiQueryAppendC4Attention(
|
||||
if (!is_decoder) {
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
uint32_t attn_mask_len;
|
||||
if (attn_mask) {
|
||||
attn_mask_len = attn_mask.get().shape()[1];
|
||||
} else {
|
||||
attn_mask_len = -1;
|
||||
}
|
||||
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
if (num_chunks <= 1) {
|
||||
if (num_chunks <= 0) {
|
||||
auto nosplit_kv_kernel =
|
||||
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1334,6 +1354,9 @@ void MultiQueryAppendC4Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1346,7 +1369,8 @@ void MultiQueryAppendC4Attention(
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1392,15 +1416,15 @@ void MultiQueryAppendC4Attention(
|
||||
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
|
||||
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
|
||||
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1410,6 +1434,9 @@ void MultiQueryAppendC4Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1422,7 +1449,8 @@ void MultiQueryAppendC4Attention(
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num);
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
@@ -1445,8 +1473,8 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1463,14 +1491,14 @@ void MultiQueryAppendC4Attention(
|
||||
constexpr int blockx = HEAD_DIM / vec_size;
|
||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||
num_heads);
|
||||
num_heads);
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
@@ -1481,8 +1509,8 @@ void MultiQueryAppendC4Attention(
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
|
||||
@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -179,6 +180,7 @@ __global__ void multi_query_append_attention_c8_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -216,7 +218,7 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(num_frags_z * 16);

uint32_t k_smem_offset_r =
@@ -300,12 +302,15 @@ __global__ void multi_query_append_attention_c8_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
s_frag);
-1,
s_frag,
mask_offset_this_seq);
}

// update m,d
@@ -474,6 +479,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -487,7 +494,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5) {
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -601,7 +609,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}

const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -642,7 +650,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);

uint32_t k_smem_offset_r =
@@ -728,12 +736,16 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(q_base_seq_id_this_block,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
s_frag);
attn_mask_len,
s_frag,
mask_offset_this_seq);

}

// update m,d
@@ -1054,6 +1066,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1111,6 +1124,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1254,10 +1268,17 @@ void MultiQueryAppendC8Attention(
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}

const int num_chunks = div_up(max_dec_len, chunk_size);
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}

dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 1) {
if (num_chunks <= 0) {
auto nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1318,6 +1339,9 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1330,7 +1354,8 @@ void MultiQueryAppendC8Attention(
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1377,8 +1402,8 @@ void MultiQueryAppendC8Attention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1388,6 +1413,9 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1400,7 +1428,8 @@ void MultiQueryAppendC8Attention(
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
speculate_max_draft_token_num,
attn_mask_len);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1418,8 +1447,8 @@ void MultiQueryAppendC8Attention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1436,14 +1465,14 @@ void MultiQueryAppendC8Attention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1454,8 +1483,8 @@ void MultiQueryAppendC8Attention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,

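The host-side hunks above thread the new tree-attention mask through MultiQueryAppendC8Attention: the launch now passes `meta_data.mask_offset`, a raw `bool*` for the optional `attn_mask`, and an `attn_mask_len` read from the mask's second dimension (with `-1` as the "no mask" sentinel). A minimal sketch of that argument preparation, using a hypothetical `OptionalBoolTensor` stand-in rather than the real paddle::optional API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for paddle::optional<paddle::Tensor>; only the pieces
// used by the dispatch logic above are modeled here.
struct OptionalBoolTensor {
  bool has_value = false;
  std::vector<int64_t> shape;   // e.g. [bsz, max_q, max_q]
  std::vector<uint8_t> data;    // flattened mask storage
  explicit operator bool() const { return has_value; }
};

// Mirrors the pattern in the diff: hand the kernel a raw pointer (or nullptr)
// plus the mask length taken from shape[1]; -1 wraps to UINT32_MAX as a sentinel.
void prepare_attn_mask_args(const OptionalBoolTensor& attn_mask,
                            const bool** mask_ptr_out,
                            uint32_t* attn_mask_len_out) {
  if (attn_mask) {
    *mask_ptr_out = reinterpret_cast<const bool*>(attn_mask.data.data());
    *attn_mask_len_out = static_cast<uint32_t>(attn_mask.shape[1]);
  } else {
    *mask_ptr_out = nullptr;
    *attn_mask_len_out = static_cast<uint32_t>(-1);  // sentinel, as in the diff
  }
}

int main() {
  OptionalBoolTensor mask;
  mask.has_value = true;
  mask.shape = {1, 4, 4};
  mask.data.assign(16, 1);

  const bool* ptr = nullptr;
  uint32_t len = 0;
  prepare_attn_mask_args(mask, &ptr, &len);
  std::cout << "attn_mask_len = " << len << ", has_ptr = " << (ptr != nullptr) << "\n";
  return 0;
}
```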
@@ -905,12 +905,15 @@ template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
bool IS_SYSTEM = false>
__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
__device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_idx_base,
const uint32_t kv_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
float (*s_frag)[num_frags_z][8]) {
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -924,10 +927,21 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
group_size,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
const bool out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
}

if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
@@ -935,6 +949,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +

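The `mask_s` change above combines three masking sources: an explicit per-query `mask_offset` (which, when present, directly bounds the visible kv range), the usual causal/chunk-boundary check, and an optional boolean tree mask indexed as `q_idx * attn_mask_len + kv_idx - kv_len + qo_len`. A scalar re-statement of that decision, assuming the convention implied by the diff that a `true` mask entry means "masked out" and that `mask_offset[q]` is the largest kv index visible to query `q`:

```cpp
#include <cstdint>
#include <iostream>

// Scalar sketch of the out_of_boundary decision in mask_s above (assumptions:
// causal attention, mask uses true == masked, mask_offset[q] = last visible kv).
bool is_masked(uint32_t q_idx, uint32_t kv_idx,
               uint32_t qo_len, uint32_t kv_len, uint32_t chunk_end,
               const int* mask_offset,   // optional per-query kv bound
               const bool* attn_mask,    // optional tree mask, row-major
               uint32_t attn_mask_len) {
  if (mask_offset) {
    // Explicit visibility bound: anything past mask_offset[q_idx] is hidden.
    return q_idx < qo_len ? (kv_idx > static_cast<uint32_t>(mask_offset[q_idx])) : true;
  }
  // Causal + chunk boundary, matching the original expression.
  bool out = (kv_idx > kv_len + q_idx - qo_len) || (kv_idx >= chunk_end);
  // The tree mask only applies to the freshly appended kv positions.
  if (attn_mask && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
    out |= attn_mask[q_idx * attn_mask_len + (kv_idx - kv_len + qo_len)];
  }
  return out;
}

int main() {
  // 2 new queries appended to 4 cached tokens (kv_len = 6, chunk_end = 6).
  bool tree_mask[2 * 2] = {false, false,
                           false, true};  // q1 may not see the second new token
  std::cout << is_masked(1, 5, 2, 6, 6, nullptr, tree_mask, 2) << "\n";  // 1: masked by tree mask
  std::cout << is_masked(1, 4, 2, 6, 6, nullptr, tree_mask, 2) << "\n";  // 0: visible
  return 0;
}
```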
@@ -18,6 +18,142 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"

template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadKVT cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec, k_norm_vec;

int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
int64_t all_head_dim = elem_cnt / head_size;

const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int half_head_size = head_size / 2;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;

const int* block_table_now = nullptr;

block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
const uint32_t ori_idx =
start_token_idx * hidden_size + hi * head_size + h_bias;

const int bias_idx = hi * head_size + h_bias;
Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;

#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);

if (hi < num_heads + kv_num_heads) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
} else {
out_vec[2 * i] = src_vec[2 * i];
out_vec[2 * i + 1] = src_vec[2 * i + 1];
}
}
if (hi < (num_heads + kv_num_heads)) { // q k
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);

if (hi < num_heads) { // q
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else { // k
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(out_vec, &qkv_out[ori_idx]);
} else {
// quant + write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
if (hi < num_heads + kv_num_heads) {
Store<T, VecSize>(out_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(out_vec, &value_cache[tgt_idx]);
}
}

}
}

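The new `append_decode_cache_T_rope_qk_norm_kernel` above applies rotary embedding to each q/k head and then RMS-normalizes the rotated head: the sum of squares of the rotated values is accumulated per thread, reduced across the warp, and the head is scaled by `rsqrt(mean_sq + eps)` times a per-channel weight. A CPU reference sketch of the same math on a single head vector, where the warp reduction collapses to a plain sum; names and layout here are illustrative only:

```cpp
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// "RoPE then RMS-norm" on one head vector: rotate (even, odd) channel pairs
// with (cos, sin), accumulate the sum of squares of the rotated values, then
// scale by rsqrt(mean_sq + eps) and the per-channel norm weight.
void rope_then_rms_norm(std::vector<float>& head,              // [head_size]
                        const std::vector<float>& cos_emb,     // [head_size / 2]
                        const std::vector<float>& sin_emb,     // [head_size / 2]
                        const std::vector<float>& norm_weight, // [head_size]
                        float rms_norm_eps) {
  const size_t head_size = head.size();
  float sum_sq = 0.0f;  // analogue of thread_m2 / warp_m2 in the kernel
  for (size_t i = 0; i < head_size / 2; ++i) {
    const float left = head[2 * i];
    const float right = head[2 * i + 1];
    const float c = cos_emb[i], s = sin_emb[i];
    const float r1 = left * c - right * s;   // rotated pair
    const float r2 = right * c + left * s;
    head[2 * i] = r1;
    head[2 * i + 1] = r2;
    sum_sq += r1 * r1 + r2 * r2;
  }
  const float inv_rms = 1.0f / std::sqrt(sum_sq / head_size + rms_norm_eps);
  for (size_t i = 0; i < head_size; ++i) {
    head[i] = head[i] * inv_rms * norm_weight[i];
  }
}

int main() {
  std::vector<float> head = {1.0f, 0.0f, 0.5f, -0.5f};
  std::vector<float> cos_emb = {1.0f, 0.8f};
  std::vector<float> sin_emb = {0.0f, 0.6f};
  std::vector<float> weight(4, 1.0f);
  rope_then_rms_norm(head, cos_emb, sin_emb, weight, 1e-6f);
  for (float v : head) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}
```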
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -199,8 +335,9 @@ __global__ void append_decode_cache_T_rope_kernel(
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -244,6 +381,142 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}

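The hunk above (and the many similar ones that follow) replaces the flat `emb_idx` lookup with `new_emb_idx`: when `rope_3d` is set, the kernel reads from a per-batch slice of the cos/sin table, offset by `ori_bi * max_seq_len * head_size`. A tiny sketch of that index selection, under the assumption suggested by the offset arithmetic that the 3-D table is laid out `[bsz, max_seq_len, head_size]`:

```cpp
#include <cstdint>
#include <iostream>

// Index into the rotary table. For the classic 2-D table the offset is just
// position-local; for a 3-D (per-batch) table each batch owns a contiguous
// [max_seq_len, head_size] slab, as the diff's offset arithmetic suggests.
inline uint32_t rope_emb_index(bool rope_3d,
                               uint32_t emb_idx,   // position-local offset
                               int batch_id,
                               int max_seq_len,
                               int head_size) {
  return rope_3d
             ? emb_idx + static_cast<uint32_t>(batch_id) * max_seq_len * head_size
             : emb_idx;
}

int main() {
  // position 7, channel pair 3, head_size 128 -> emb_idx as in the kernel
  const int head_size = 128, max_seq_len = 1024;
  const uint32_t emb_idx = 7 * (head_size / 2) + 3;
  std::cout << rope_emb_index(false, emb_idx, 2, max_seq_len, head_size) << "\n";  // 451
  std::cout << rope_emb_index(true, emb_idx, 2, max_seq_len, head_size) << "\n";   // 262595
  return 0;
}
```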
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int rotary_dim,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, VecSize>;

LoadT left_vec, right_vec;
LoadBiasT left_bias_vec, right_bias_vec;
LoadKVT left_cache_vec, right_cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;

int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int half_head_size = head_size / 2;
const int half_rotary_dim = rotary_dim / 2;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int64_t half_hidden_size = hidden_size / 2;
// const int64_t offset = 2 * hidden_size;

for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / half_hidden_size;
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
if (hi < num_heads && h_bias >= half_rotary_dim){
continue;
}
if (seq_lens_encoder[ori_bi] > 0) continue;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int start_token_idx = cu_seqlens_q[ori_bi];

const int* block_table_now = nullptr;

block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
uint32_t ori_idx_left =
start_token_idx * hidden_size + hi * head_size + h_bias;
uint32_t ori_idx_right = ori_idx_left + half_head_size;
if (hi < num_heads){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else if (hi < num_heads + kv_num_heads){
if (h_bias < half_rotary_dim){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else{
ori_idx_left = ori_idx_left + half_rotary_dim;
ori_idx_right = ori_idx_left + half_rotary_dim;
}
}

Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);

if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
if (h_bias < half_rotary_dim){
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// rope
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
} else {
// write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
uint32_t tgt_idx_left =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
if (hi < num_heads + kv_num_heads) {
if (h_bias < half_rotary_dim) {
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}else{
tgt_idx_left = tgt_idx_left + half_rotary_dim;
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
} else {
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
}
}
}
}

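The new `append_decode_cache_T_neox_partial_rope_kernel` above handles partial rotary embedding (`rotary_dim < head_size`) in the NeoX layout: only the first `rotary_dim` channels are rotated, pairing channel `i` with channel `i + rotary_dim/2`, while the remaining channels pass through unchanged. A scalar sketch of that per-head transform, with the pairing taken from the kernel's index math; the layout and names are illustrative:

```cpp
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// NeoX-style partial rotary on one head vector: channels [0, rotary_dim) are
// rotated in (i, i + rotary_dim/2) pairs, channels [rotary_dim, head_size)
// are copied through untouched. cos/sin are indexed per pair.
void neox_partial_rope(std::vector<float>& head,            // [head_size]
                       const std::vector<float>& cos_row,   // [rotary_dim / 2]
                       const std::vector<float>& sin_row,   // [rotary_dim / 2]
                       int rotary_dim) {
  const int half_rotary_dim = rotary_dim / 2;
  for (int i = 0; i < half_rotary_dim; ++i) {
    const float left = head[i];                    // "left" half of the pair
    const float right = head[i + half_rotary_dim]; // "right" half of the pair
    const float c = cos_row[i], s = sin_row[i];
    head[i] = left * c - right * s;
    head[i + half_rotary_dim] = right * c + left * s;
  }
  // Channels >= rotary_dim are intentionally left as-is.
}

int main() {
  std::vector<float> head = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
  const int rotary_dim = 4;  // only the first 4 of 8 channels are rotated
  std::vector<float> cos_row = {1.0f, 0.0f};
  std::vector<float> sin_row = {0.0f, 1.0f};
  neox_partial_rope(head, cos_row, sin_row, rotary_dim);
  for (float v : head) std::cout << v << " ";  // 1 -4 3 2 5 6 7 8
  std::cout << "\n";
  return 0;
}
```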
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -266,7 +539,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadKVT = AlignedVector<T, VecSize>;
|
||||
@@ -313,8 +587,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -382,7 +657,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<int, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
@@ -439,8 +715,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -512,7 +789,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -555,8 +833,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -633,10 +912,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
const T *cache_v_scale_cur = cache_v_scale + v_head_idx * HeadDim + head_bias;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
if constexpr (!is_scale_channel_wise) {
|
||||
scale = __ldg(&cache_k_scale[kv_head_idx]);
|
||||
}
|
||||
@@ -763,7 +1043,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -813,9 +1094,10 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -908,10 +1190,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
const T *cache_v_scale_cur = cache_v_scales + v_head_idx * HeadDim + head_bias;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
if constexpr (!is_scale_channel_wise) {
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
}
|
||||
@@ -1061,7 +1344,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1109,8 +1393,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -1191,10 +1476,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
@@ -1364,7 +1650,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1424,8 +1711,10 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -1533,10 +1822,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
@@ -1755,7 +2045,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1799,8 +2090,9 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -1874,10 +2166,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx], &scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx + 8], &scale_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||
@@ -2054,7 +2347,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -2103,8 +2397,9 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -2191,10 +2486,11 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
&out_scale_vec2);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx], &scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[cache_idx + 8], &scale_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||
@@ -2378,7 +2674,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -2425,8 +2722,9 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -2507,10 +2805,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[right_bias_idx], &right_src_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[right_bias_idx + 8], &right_src_vec2);
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx],
|
||||
&left_scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx + 8],
|
||||
@@ -2752,7 +3051,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads) {
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -2810,8 +3110,9 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -2920,10 +3221,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec2);
|
||||
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx],
|
||||
&left_scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scale[left_cache_idx + 8],
|
||||
|
||||
@@ -15,6 +15,69 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"

template <typename T, typename QKV_TYPE>
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const uint32_t elem_nums =
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;

constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}

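The host wrapper above launches the qk-norm kernel with one warp per head: each of the `kWarpSize` lanes handles `PackSize = HEAD_DIM / kWarpSize` contiguous channels, a block carries `blocksize / kWarpSize` warps, and the grid is sized from `pack_num = elem_nums / PackSize`. A small sketch of that launch arithmetic, assuming a warp size of 32 and a plain ceiling division standing in for `GetNumBlocks` (the real helper also caps the grid):

```cpp
#include <cstdint>
#include <iostream>

// Launch-shape arithmetic mirroring append_decode_cache_rope_qk_norm above.
// Assumptions: warp size 32, HEAD_DIM 128, ceil-div in place of GetNumBlocks.
struct LaunchShape {
  int grid_x;
  int block_x;  // lanes per warp
  int block_y;  // warps per block
};

LaunchShape qk_norm_launch_shape(int bsz, int num_heads, int kv_num_heads, int dim_head) {
  constexpr int kWarpSize = 32;
  constexpr int HEAD_DIM = 128;
  constexpr int PackSize = HEAD_DIM / kWarpSize;   // 4 channels per lane
  const int blocksize = 128;                       // threads per block
  const uint32_t elem_nums =
      static_cast<uint32_t>(bsz) * (num_heads + 2 * kv_num_heads) * dim_head;
  const int pack_num = static_cast<int>(elem_nums / PackSize);
  const int grid_size = (pack_num + blocksize - 1) / blocksize;  // ceil-div stand-in
  return {grid_size, kWarpSize, blocksize / kWarpSize};
}

int main() {
  LaunchShape s = qk_norm_launch_shape(/*bsz=*/8, /*num_heads=*/32, /*kv_num_heads=*/4, /*dim_head=*/128);
  std::cout << "grid=" << s.grid_x << " block=(" << s.block_x << "," << s.block_y << ")\n";
  return 0;
}
```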
template <typename T, typename QKV_TYPE>
|
||||
void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
@@ -34,6 +97,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const cudaStream_t& stream,
|
||||
@@ -71,9 +135,32 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
if (rotary_dim < dim_head){
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}else{
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -91,7 +178,9 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -198,7 +287,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int8_neox_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
@@ -221,7 +311,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -248,7 +339,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
@@ -271,7 +363,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -335,7 +428,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int4_neox_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
@@ -360,7 +454,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -389,7 +484,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int4_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
@@ -414,7 +510,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -441,7 +538,10 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out) {
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
typedef cascade_attn_type_traits<T> traits_;
|
||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||
typedef typename traits_::type DataType_;
|
||||
@@ -458,113 +558,93 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const float* cos_emb =
|
||||
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
|
||||
const float* sin_emb;
|
||||
int rotary_dim = dim_head;
|
||||
if (rotary_embs) {
|
||||
sin_emb =
|
||||
use_neox_rotary_style
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < dim_head){
|
||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
|
||||
}
|
||||
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
||||
}
|
||||
}
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_decode_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
bool is_scale_channel_wise = false;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_decode_cache_rope_qk_norm(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_decode_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
bool is_scale_channel_wise = false;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -596,49 +676,117 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
} else {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
}
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8, "
"cache_int4_zp]");
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8, "
"cache_int4_zp]");
}
}
}

@@ -667,7 +815,10 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

template void
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -694,7 +845,10 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -720,7 +874,10 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -746,4 +903,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);

@@ -40,4 +40,6 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight, const float rms_norm_eps);

@@ -33,7 +33,8 @@ __global__ void VariableLengthRotaryKernel(
|
||||
const int64_t elem_cnt,
|
||||
const int num_head,
|
||||
const int seq_len,
|
||||
const int last_dim) {
|
||||
const int last_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<int, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadScaleT = AlignedVector<float, VecSize>;
|
||||
@@ -64,6 +65,7 @@ __global__ void VariableLengthRotaryKernel(
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
|
||||
const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias;
|
||||
const int64_t base_idx = token_idx * 3 * hidden_size + bias_idx;
|
||||
Load<int, VecSize>(&qkv[base_idx], &src_vec);
|
||||
@@ -72,8 +74,8 @@ __global__ void VariableLengthRotaryKernel(
|
||||
}
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
if (qkv_id < 2) {
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -115,7 +117,8 @@ __global__ void VariableLengthRotaryKernel(
|
||||
const int64_t elem_cnt,
|
||||
const int num_head,
|
||||
const int seq_len,
|
||||
const int last_dim) {
|
||||
const int last_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
@@ -142,11 +145,12 @@ __global__ void VariableLengthRotaryKernel(
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
|
||||
const int64_t base_idx = token_idx * 3 * hidden_size +
|
||||
qkv_id * hidden_size + hi * last_dim + h_bias;
|
||||
Load<T, VecSize>(&qkv[base_idx], &src_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
const float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
@@ -177,7 +181,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
|
||||
const int64_t elem_cnt,
|
||||
const int num_head,
|
||||
const int seq_len,
|
||||
const int last_dim) {
|
||||
const int last_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<int, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadScaleT = AlignedVector<float, VecSize>;
|
||||
@@ -211,6 +216,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * last_dim + h_bias;
|
||||
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
|
||||
const int bias_idx_left =
|
||||
qkv_id * full_hidden_size + hi * last_dim + h_bias;
|
||||
const int bias_idx_right = bias_idx_left + half_lastdim;
|
||||
@@ -225,8 +231,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left], &left_out_scale_vec);
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_right], &right_out_scale_vec);
|
||||
if (qkv_id < 2) {
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -269,7 +275,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
|
||||
const int64_t elem_cnt,
|
||||
const int num_head,
|
||||
const int seq_len,
|
||||
const int last_dim) {
|
||||
const int last_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT left_vec;
|
||||
@@ -297,6 +304,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * last_dim + h_bias;
|
||||
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
|
||||
const int base_idx_left = token_idx * 3 * full_hidden_size +
|
||||
qkv_id * full_hidden_size + hi * last_dim +
|
||||
h_bias;
|
||||
@@ -304,8 +312,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
|
||||
|
||||
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
const float input_left = static_cast<float>(left_vec[i]);
|
||||
@@ -358,7 +366,7 @@ __global__ void GQAVariableLengthRotaryKernel(
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / last_dim;
|
||||
@@ -367,6 +375,7 @@ __global__ void GQAVariableLengthRotaryKernel(
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
|
||||
const int64_t bias_idx = hi * last_dim + h_bias;
|
||||
const int64_t base_idx = token_idx * offset + bias_idx;
|
||||
Load<int, VecSize>(&qkv[base_idx], &src_vec);
|
||||
@@ -375,8 +384,8 @@ __global__ void GQAVariableLengthRotaryKernel(
|
||||
}
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
if (hi < q_num_head + kv_num_head) {
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -405,6 +414,97 @@ __global__ void GQAVariableLengthRotaryKernel(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryQKNormKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps
) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
LoadT src_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec, k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
const int half_lastdim = last_dim / 2;
const int offset = (q_num_head + kv_num_head) * last_dim;
const int all_head_num = elem_cnt / last_dim;
for (int global_hi = global_warp_idx; global_hi < all_head_num; global_hi += all_warp_num) {
int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize;
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;

const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);

int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);

float thread_m2 = 0.0f;
float warp_m2 = 0.0f;

#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
}
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / last_dim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);

if (hi < q_num_head) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
for (int i = 0; i < VecSize; i++) {
src_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
}
}

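// Note (not part of the diff): a minimal host-side sketch of the per-head math the
// QK-norm kernel above implements — interleaved rotary rotation of (even, odd) pairs,
// followed by an RMS-norm scale with a learned per-channel weight. Function and
// variable names below are ours, chosen only for illustration.
#include <cmath>
#include <vector>

std::vector<float> rope_then_rmsnorm(const std::vector<float>& head,     // [last_dim]
                                     const std::vector<float>& cos_emb,  // [last_dim / 2]
                                     const std::vector<float>& sin_emb,  // [last_dim / 2]
                                     const std::vector<float>& norm_w,   // [last_dim]
                                     float rms_norm_eps) {
  const size_t last_dim = head.size();
  std::vector<float> rotated(last_dim);
  float sum_sq = 0.0f;  // plays the role of warp_m2 in the kernel
  for (size_t i = 0; i < last_dim / 2; ++i) {
    const float x0 = head[2 * i];
    const float x1 = head[2 * i + 1];
    const float r0 = x0 * cos_emb[i] - x1 * sin_emb[i];
    const float r1 = x1 * cos_emb[i] + x0 * sin_emb[i];
    rotated[2 * i] = r0;
    rotated[2 * i + 1] = r1;
    sum_sq += r0 * r0 + r1 * r1;
  }
  const float inv_rms = 1.0f / std::sqrt(sum_sq / last_dim + rms_norm_eps);
  for (size_t i = 0; i < last_dim; ++i) {
    rotated[i] *= inv_rms * norm_w[i];  // q_norm_weight for query heads, k_norm_weight for key heads
  }
  return rotated;
}
// The kernel performs the sum-of-squares reduction cooperatively across a warp via
// WelfordWarpAllReduce; the scalar loop above collapses that reduction.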
template <typename T, int VecSize = 1>
|
||||
__global__ void GQAVariableLengthRotaryKernel(
|
||||
const T *qkv,
|
||||
@@ -514,6 +614,7 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
|
||||
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
|
||||
const int64_t bias_idx = hi * last_dim + h_bias;
|
||||
const int64_t base_idx = token_idx * offset + bias_idx;
|
||||
Load<int, VecSize>(&qkv[base_idx], &src_vec);
|
||||
@@ -521,8 +622,8 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
|
||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||
}
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
@@ -599,14 +700,15 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
|
||||
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
|
||||
const int64_t bias_idx = hi * last_dim + h_bias;
|
||||
const int64_t base_idx = token_idx * offset + bias_idx;
|
||||
Load<T, VecSize>(&qkv[base_idx], &src_vec);
|
||||
if (qkv_biases) {
|
||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||
}
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
const float input_left = qkv_biases ? static_cast<float>(src_vec[2 * i]+ bias_vec[2 * i]) : static_cast<float>(src_vec[2 * i]);
|
||||
@@ -654,7 +756,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int last_dim) {
|
||||
const int last_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<int, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadScaleT = AlignedVector<float, VecSize>;
|
||||
@@ -684,6 +787,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * last_dim + h_bias;
|
||||
int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
|
||||
const int bias_idx_left = hi * last_dim + h_bias;
|
||||
const int bias_idx_right = bias_idx_left + half_lastdim;
|
||||
const int base_idx_left =
|
||||
@@ -698,8 +802,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left], &left_out_scale_vec);
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_right], &right_out_scale_vec);
|
||||
if (hi < (q_num_head + kv_num_head)) {
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -745,7 +849,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int last_dim) {
|
||||
const int last_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT left_vec;
|
||||
@@ -769,6 +874,7 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * last_dim + h_bias;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx;
|
||||
const int base_idx_left =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
|
||||
h_bias;
|
||||
@@ -776,8 +882,76 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
|
||||
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
const float input_left = static_cast<float>(left_vec[i]);
|
||||
const float input_right = static_cast<float>(right_vec[i]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const float *qkv_out_scales,
|
||||
const T *qkv_biases,
|
||||
T *qkv_out,
|
||||
const int64_t elem_cnt,
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int head_dim,
|
||||
const int rotary_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT left_vec;
|
||||
LoadT right_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int rotary_dim_half = rotary_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
|
||||
for (int64_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens && seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / rotary_dim_half;
|
||||
const int h_bias = bias % rotary_dim_half;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
|
||||
const int base_idx_left =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
|
||||
h_bias;
|
||||
const int base_idx_right = base_idx_left + rotary_dim_half;
|
||||
|
||||
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
const float input_left = static_cast<float>(left_vec[i]);
|
||||
@@ -1512,7 +1686,8 @@ void rotary_qk_variable(
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head);
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
VariableLengthRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
@@ -1527,7 +1702,8 @@ void rotary_qk_variable(
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head);
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
} else {
|
||||
const float *cos_emb = rotary_emb;
|
||||
@@ -1548,7 +1724,8 @@ void rotary_qk_variable(
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head);
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
NeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
@@ -1563,11 +1740,72 @@ void rotary_qk_variable(
|
||||
elem_nums,
|
||||
head_num,
|
||||
seq_len,
|
||||
dim_head);
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE>
void gqa_rotary_qk_norm_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
const QKV_TYPE *qkv_input, // qkv
const float *qkv_out_scales, // [3, num_head, dim_head]
const T *qkv_bias,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
const int num_heads,
const int kv_num_heads,
const int seq_len,
const int input_output_len,
const int dim_head,
const cudaStream_t &stream,
bool use_neox_style = false,
bool rope_3d = false,
const float *q_norm_weight = nullptr,
const float *k_norm_weight = nullptr,
const float rms_norm_eps = 1e-6) {
int64_t elem_nums =
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
assert(dim_head == 128 && "dim_head must be 128");
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1);

const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;

GQAVariableLengthRotaryQKNormKernel<T, PackSize>
<<<grid_size, Block_Size, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}

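// Note (not part of the diff): the launch geometry used by gqa_rotary_qk_norm_variable
// follows from HEAD_DIM = 128 and a 32-lane warp — each warp owns one head vector and
// each lane owns PackSize = 4 contiguous elements. A tiny compile-time sketch of that
// arithmetic (names here are ours):
#include <cstdio>

constexpr int kWarpSizeDemo = 32;
constexpr int kHeadDimDemo = 128;                                      // asserted in the wrapper above
constexpr int kPackSizeDemo = kHeadDimDemo / kWarpSizeDemo;            // 4 elements per lane
constexpr int kBlockThreadsDemo = 128;                                 // "blocksize" in the wrapper
constexpr int kWarpsPerBlockDemo = kBlockThreadsDemo / kWarpSizeDemo;  // blockDim.y = 4 heads per block

static_assert(kPackSizeDemo * kWarpSizeDemo == kHeadDimDemo,
              "one warp must cover exactly one head vector");

int main() {
  std::printf("pack size = %d, warps per block = %d\n", kPackSizeDemo, kWarpsPerBlockDemo);
  return 0;
}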
template <typename T, typename QKV_TYPE>
|
||||
void gqa_rotary_qk_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
||||
@@ -1585,6 +1823,7 @@ void gqa_rotary_qk_variable(
|
||||
const int seq_len,
|
||||
const int input_output_len,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const cudaStream_t &stream,
|
||||
bool use_neox_style = false,
|
||||
bool rope_3d = false) {
|
||||
@@ -1662,9 +1901,41 @@ void gqa_rotary_qk_variable(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head);
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
if (rotary_dim < dim_head){
|
||||
PD_CHECK((rotary_dim / 2) % PackSize == 0);
|
||||
elem_nums =
|
||||
qkv_out_scales
|
||||
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
|
||||
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
|
||||
if (use_neox_style) {
|
||||
elem_nums /= 2;
|
||||
}
|
||||
const int pack_num_new = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num_new, &grid_size);
|
||||
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
rotary_emb + input_output_len * rotary_dim / 2,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
rope_3d);
|
||||
}else{
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
@@ -1680,7 +1951,9 @@ void gqa_rotary_qk_variable(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head);
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
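// Note (not part of the diff): when rotary_dim < dim_head, the partial NeoX-style path
// dispatched above rotates only the first rotary_dim channels of each head (split into
// two halves) and leaves the remaining channels untouched. A host-side sketch with our
// own illustrative names:
#include <vector>

std::vector<float> partial_neox_rope(const std::vector<float>& head,     // [dim_head]
                                     const std::vector<float>& cos_emb,  // [rotary_dim / 2]
                                     const std::vector<float>& sin_emb,  // [rotary_dim / 2]
                                     int rotary_dim) {
  std::vector<float> out = head;  // channels >= rotary_dim pass through unchanged
  const int half = rotary_dim / 2;
  for (int i = 0; i < half; ++i) {
    const float x_left = head[i];
    const float x_right = head[i + half];
    out[i] = x_left * cos_emb[i] - x_right * sin_emb[i];
    out[i + half] = x_right * cos_emb[i] + x_left * sin_emb[i];
  }
  return out;
}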
|
||||
|
||||
@@ -46,38 +46,32 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out) {
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto head_dim = meta_data.head_dims;
|
||||
bool is_scale_channel_wise = false;
|
||||
int rotary_dim = head_dim;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (rotary_embs) {
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if (rotary_dim < head_dim) {
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise) {
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight is None, GQA, and is_scale_channel_wise=false."));
}
}
}

if (num_heads == kv_num_heads) {
|
||||
rotary_qk_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (!is_scale_channel_wise) {
|
||||
gqa_rotary_qk_variable(
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
|
||||
gqa_rotary_qk_norm_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
@@ -95,31 +89,81 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
rope_3d,
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
gqa_rotary_qk_quant_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
|
||||
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
PD_THROW(
"gqa_rotary_qk_norm_variable only supports gqa mode. channel wise scale and neox style are not supported");
}
|
||||
} else {
|
||||
if (num_heads == kv_num_heads) {
|
||||
rotary_qk_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (!is_scale_channel_wise) {
|
||||
gqa_rotary_qk_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
gqa_rotary_qk_quant_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
|
||||
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
const uint32_t block_size = meta_data.block_size;
|
||||
if (cache_quant_type_str == "none") {
|
||||
|
||||
@@ -289,7 +289,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
kv_tile_ids_per_batch =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
kv_num_blocks_x_cpu =
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||
}
|
||||
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
|
||||
@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int last_dim) {
|
||||
const int last_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
|
||||
|
||||
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
|
||||
const int64_t base_idx =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
|
||||
h_bias;
|
||||
@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
|
||||
Load<T, VecSize>(&qkv[base_idx], &src_vec);
|
||||
// do rope
|
||||
if (hi < q_num_head + kv_num_head) {
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
const float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
|
||||
const int seq_len,
|
||||
const int input_output_len,
|
||||
const int dim_head,
|
||||
const bool rope_3d,
|
||||
const cudaStream_t &stream) {
|
||||
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
constexpr int PackSize = 16 / sizeof(T);
|
||||
@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head);
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
@@ -890,7 +894,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const int kv_token_num,
|
||||
const int max_seq_len,
|
||||
const std::string& cache_quant_type) {
|
||||
const std::string& cache_quant_type,
|
||||
const bool rope_3d) {
|
||||
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.dims()[2],
|
||||
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
|
||||
head_dim,
|
||||
rope_3d,
|
||||
stream);
|
||||
|
||||
if (token_num < kv_token_num) {
|
||||
|
||||
@@ -43,4 +43,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
@@ -27,6 +27,7 @@ struct AppendAttnMetaData {
|
||||
int head_dims;
|
||||
int head_dims_v;
|
||||
int max_blocks_per_seq;
|
||||
const int *mask_offset = nullptr;
|
||||
};
|
||||
|
||||
__forceinline__ __host__ __device__ int div_up(int a, int b) {
|
||||
@@ -430,6 +431,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} else if (group_size == 12) { \
|
||||
constexpr size_t GROUP_SIZE = 12; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 14) { \
|
||||
constexpr size_t GROUP_SIZE = 14; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
@@ -474,6 +478,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
if (causal) { \
|
||||
constexpr bool CAUSAL = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
constexpr bool CAUSAL = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \
|
||||
@@ -559,3 +566,37 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
|
||||
convert_int8(result, source);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int kWarpSize = 32;

template<typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
*m2 += b_m2;
}

template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
*m2 = thread_m2;
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
WelfordCombine1(b_m2, m2);
}
}

template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}

template <typename T>
__inline__ __device__ T Rsqrt(T x);

template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}

template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}

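// Note (not part of the diff): despite the Welford naming, the helpers above only
// accumulate a sum (here, a sum of squares) across the 32 lanes of a warp via an XOR
// butterfly. A self-contained CUDA sketch of the same pattern, with our own names:
#include <cstdio>
#include <cuda_runtime.h>

__device__ float warp_all_reduce_sum(float v) {
  for (int mask = 16; mask > 0; mask >>= 1) {
    v += __shfl_xor_sync(0xffffffff, v, mask);  // every lane ends up with the full sum
  }
  return v;
}

__global__ void warp_sum_demo() {
  // Each lane contributes its lane id; the warp-wide sum is 0 + 1 + ... + 31 = 496.
  const float sum = warp_all_reduce_sum(static_cast<float>(threadIdx.x));
  if (threadIdx.x == 0) {
    printf("warp sum = %.0f\n", sum);
  }
}

int main() {
  warp_sum_demo<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}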
@@ -77,7 +77,54 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor> &mask_offset,
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
||||
const bool use_neox_rotary_style, const bool rope_3d,
|
||||
const int max_input_length, const float quant_max_bound,
|
||||
const float quant_min_bound, const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int max_partition_size, const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num, const bool causal,
|
||||
const bool speculate_decoder);
|
||||
|
||||
void AppendAttentionWithOutput(
|
||||
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
|
||||
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
|
||||
const paddle::Tensor &encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &encoder_num_blocks,
|
||||
const paddle::Tensor &kv_batch_ids,
|
||||
const paddle::Tensor &kv_tile_ids_per_batch,
|
||||
const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &decoder_batch_ids,
|
||||
const paddle::Tensor &decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &decoder_num_blocks,
|
||||
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
|
||||
paddle::Tensor &fmha_out,
|
||||
const paddle::optional<paddle::Tensor> &rotary_embs,
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
const paddle::optional<paddle::Tensor> &qkv_bias,
|
||||
const paddle::optional<paddle::Tensor> &qkv_out_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor> &mask_offset,
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
||||
const bool use_neox_rotary_style, const bool rope_3d,
|
||||
const int max_input_length, const float quant_max_bound,
|
||||
@@ -107,7 +154,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const int kv_token_num, const int max_seq_len,
|
||||
const std::string &cache_quant_type);
|
||||
const std::string &cache_quant_type,
|
||||
const bool rope_3d);
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
|
||||
@@ -124,11 +172,29 @@ paddle::Tensor FusedExpertMoeFunc(
|
||||
const std::string &quant_method, const int moe_topk,
|
||||
const bool norm_topk_prob, const bool group_moe);
|
||||
|
||||
std::vector<paddle::Tensor> MacheteMMKernel(
|
||||
paddle::Tensor const& A, paddle::Tensor const& B,
|
||||
paddle::optional<paddle::Tensor> const& maybe_group_scales,
|
||||
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
|
||||
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
|
||||
paddle::optional<paddle::Tensor> const& maybe_token_scales,
|
||||
std::string const& b_type_str,
|
||||
std::string const& maybe_out_type_str,
|
||||
int64_t const& maybe_group_size,
|
||||
std::string const& maybe_schedule);
|
||||
|
||||
std::vector<paddle::Tensor> MachetePrepackBKernel(
|
||||
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
|
||||
std::string const& maybe_group_scales_type_str);
|
||||
|
||||
std::vector<std::string> MacheteSupportedSchedules(
|
||||
std::string const& a_type_str, std::string const& b_type_str);
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||
const bool group_moe, const bool topk_only_mode);
|
||||
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode);
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
|
||||
@@ -188,7 +254,8 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency);
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums);
|
||||
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
@@ -323,7 +390,7 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
|
||||
const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index,
|
||||
const paddle::Tensor &mm_token_num_len,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text);
|
||||
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &hidden_states);
|
||||
|
||||
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
|
||||
const paddle::Tensor &topk_idx,
|
||||
@@ -497,6 +564,7 @@ std::vector<paddle::Tensor> NoauxTc(
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor);
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
@@ -526,7 +594,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
|
||||
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
|
||||
|
||||
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
|
||||
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
|
||||
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
|
||||
|
||||
void dispose(int64_t _fa);
|
||||
@@ -548,6 +616,8 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
|
||||
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
void clear_ipc_handles(int64_t _fa);
|
||||
|
||||
// speculative decoding Kernel
|
||||
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
const paddle::Tensor& input_ids,
|
||||
@@ -609,7 +679,7 @@ void SpeculateVerify(
|
||||
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
|
||||
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
|
||||
|
||||
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
|
||||
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
@@ -654,6 +724,20 @@ void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
void HybridMtpNgram(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &draft_token_num,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const int max_ngram_size,
|
||||
const int min_ngram_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
// MTP
|
||||
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
@@ -670,8 +754,10 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||
const paddle::Tensor& base_model_step_idx,
|
||||
@@ -761,6 +847,33 @@ void SpeculateStepPaddle(
|
||||
const int encoder_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void MergePrefillDecodeOutput(
|
||||
const paddle::Tensor &encoder_res,
|
||||
const paddle::Tensor &decoder_res,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &cu_seq_q,
|
||||
const int head_num,
|
||||
const int head_dim,
|
||||
const int max_token);
|
||||
|
||||
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
|
||||
const paddle::Tensor &top_p,
|
||||
const paddle::optional<paddle::Tensor> &top_k,
|
||||
int64_t seed);
|
||||
|
||||
std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
|
||||
const paddle::Tensor &top_k);
|
||||
|
||||
std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
|
||||
const paddle::Tensor &min_p);
|
||||
|
||||
void SaveOutMmsgStatic(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
bool save_each_rank);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -814,6 +927,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   * append_attention
   */
  m.def("append_attention", &AppendAttention, "append attention function");
  m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
  /**
   * gqa_rope_write_cache.cu
   * gqa_rope_write_cache
@@ -845,7 +959,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
  m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
        py::arg("gating_output"), py::arg("gating_correction_bias"),
        py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
        py::arg("topk_only_mode"), "moe export dispatch function");
        py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function");

  /**
   * moe/fused_moe/ep_moe_prefill_func.cu
@@ -875,6 +989,27 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("recv_expert_count"), py::arg("block_size"),
        "per token per block quant");

#ifdef ENABLE_MACHETE
  /*machete/machete_mm.cu
   * machete_mm
   */
  m.def("machete_mm", &MacheteMMKernel, py::arg("A"), py::arg("B"), py::arg("maybe_group_scale"),
        py::arg("maybe_group_zeros"), py::arg("maybe_channel_scales"), py::arg("maybe_token_scales"),
        py::arg("b_type_str"), py::arg("maybe_out_type_str"), py::arg("maybe_group_size"),
        py::arg("maybe_schedule"),
        "machete mm function");

  /*machete/machete_prepack_B.cu
   * machete_prepack_B
   */
  m.def("machete_prepack_B", &MachetePrepackBKernel, "machete prepacked B function");

  /*machete/machete_supported_schedules.cu
   * machete_supported_schedules
   */
  m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function");
#endif

  /**
   * moe/fused_moe/moe_topk_select.cu
   * moe_topk_select
@@ -1071,6 +1206,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");

  m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");

  m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");

  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
@@ -1088,7 +1225,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");

  m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
  m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");

  m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");

@@ -1098,6 +1235,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("ngram_match", &NgramMatch, "ngram_match function");

  m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");

  m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");

  m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
@@ -1111,4 +1250,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
  m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");

  m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");

  m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function");

  m.def("rejection_top_p_sampling", &TopPSamplingReject, "rejection_top_p_sampling function");

  m.def("top_k_renorm_probs", &TopKRenorm, "top_k_renorm_probs function");

  m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function");

  m.def("save_output", &SaveOutMmsgStatic, "save_output function");
}
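For orientation, here is a minimal sketch of how the newly registered sampling bindings (rejection_top_p_sampling, top_k_renorm_probs, min_p_sampling) might be exercised from Python once the extension is built. The module name fastdeploy_ops comes from the PYBIND11_MODULE above, but the direct import path, the acceptance of paddle.Tensor arguments across the pybind boundary, and the exact dtypes and return packing are assumptions of this sketch, not guarantees of the API.

# Hypothetical usage sketch; assumes the built extension is importable as
# `fastdeploy_ops` and that paddle.Tensor converts across the pybind boundary.
import paddle
import fastdeploy_ops as ops

probs = paddle.nn.functional.softmax(paddle.rand([4, 32000]), axis=-1)
top_p = paddle.full([4], 0.8, dtype="float32")
top_k = paddle.full([4], 50, dtype="int64")   # dtype assumed
min_p = paddle.full([4], 0.05, dtype="float32")

# TopPSamplingReject(probs, top_p, optional top_k, seed) -> vector of output tensors
sampled = ops.rejection_top_p_sampling(probs, top_p, top_k, -1)

# TopKRenorm(probs, top_k): renormalize the distribution over the top-k set
renormed = ops.top_k_renorm_probs(probs, top_k)

# MinPSamplingFromProbs(probs, min_p): min-p filtering/sampling from the distribution
filtered = ops.min_p_sampling(probs, min_p)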
@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
 */
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  auto stream = inp.stream();
@@ -122,10 +122,14 @@ void register_graph_buffers(fptr_t _fa,
  for (int i = 0; i < handles.size(); i++) {
    bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  bytes.reserve(handles.size());
  fa->register_graph_buffers(bytes, offsets);
}

void clear_ipc_handles(fptr_t _fa) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  fa->clear_ipc_handles();
}

std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
    int64_t size) {

@@ -163,3 +167,12 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void free_shared_buffer(fptr_t buffer) {
  CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}


PD_BUILD_STATIC_OP(all_reduce)
    .Inputs({"inp",
             "out"})
    .Outputs({"new_out"})
    .Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
    .SetInplaceMap({{"out", "new_out"}})
    .SetKernelFn(PD_KERNEL(all_reduce));
@@ -517,10 +517,15 @@ class CustomAllreduce {
#undef KL
  }

  ~CustomAllreduce() {
  void clear_ipc_handles() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
    ipc_handles_.clear();
  }

  ~CustomAllreduce() {
    clear_ipc_handles();
  }
};
}  // namespace paddle
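With the IPC-handle cleanup factored out of the destructor into clear_ipc_handles(), and a matching clear_ipc_handles binding added to the module above, callers can release the cudaIpc mappings explicitly before freeing their own buffers. A hedged ordering sketch follows; the `fa` and shared-buffer handles come from the communicator init and allocator bindings not shown in this hunk, so the names and types here are illustrative only.

# Hypothetical teardown helper built on the bindings registered above.
# `fa` and `shared_buf_ptr` are the int64 handles returned by the (not shown)
# communicator init and shared-buffer allocation bindings.
import fastdeploy_ops as ops

def teardown_custom_allreduce(fa: int, shared_buf_ptr: int) -> None:
    ops.clear_ipc_handles(fa)               # close peer cudaIpc mem handles first
    ops.free_shared_buffer(shared_buf_ptr)  # then free the locally owned buffer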
@@ -6,6 +6,8 @@
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"

#include "helper.h"
// clang-format on

/*
@@ -20,7 +20,7 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len,
                                                 int *mm_token_num_len,
                                                 int *seq_lens_this_time,
                                                 int *cu_seqlens_q,
                                                 float *score_text,
                                                 float *hidden_states,
                                                 float *output,
                                                 const int bsz,
                                                 const int hidden_size) {
@@ -32,14 +32,11 @@ __global__ void extract_text_token_output_kernel(int *max_seq_len,
  int max_seq_len_index_data = max_seq_len_index[0];
  int mm_token_num_len_data = mm_token_num_len[0];
  int true_bsz = cu_seqlens_q[bsz_index + 1] - 1;
  if (bsz_index >= max_seq_len_index_data) {
    true_bsz = true_bsz - mm_token_num_len_data;
  }
  if (max_seq_len_data == mm_token_num_len_data && bsz_index == max_seq_len_index_data) {
    output[bsz_index * hidden_size + block_idx] = 0.0;
  } else {
    if (seq_lens_this_time[bsz_index] != 0) {
      output[bsz_index * hidden_size + block_idx] = score_text[true_bsz * hidden_size + block_idx];
      output[bsz_index * hidden_size + block_idx] = hidden_states[true_bsz * hidden_size + block_idx];
    }
  }
  __syncthreads();
@@ -51,19 +48,19 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
    const paddle::Tensor& mm_token_num_len,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& score_text) {
    const paddle::Tensor& hidden_states) {

  const int bsz = seq_lens_this_time.shape()[0];
  const int hidden_size = score_text.shape()[1];
  paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, score_text.place());
  const int hidden_size = hidden_states.shape()[1];
  paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, hidden_states.place());

  extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, score_text.stream()>>>(
  extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, hidden_states.stream()>>>(
    const_cast<int*>(max_seq_len.data<int>()),
    const_cast<int*>(max_seq_len_index.data<int>()),
    const_cast<int*>(mm_token_num_len.data<int>()),
    const_cast<int*>(seq_lens_this_time.data<int>()),
    const_cast<int*>(cu_seqlens_q.data<int>()),
    const_cast<float*>(score_text.data<float>()),
    const_cast<float*>(hidden_states.data<float>()),
    output.data<float>(),
    bsz,
    hidden_size
@@ -76,9 +73,9 @@ std::vector<std::vector<int64_t>> ExtractTextTokenOutputInferShape(const std::ve
    const std::vector<int64_t>& mm_token_num_len_shape,
    const std::vector<int64_t>& seq_lens_this_time_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& score_text_shape) {
    const std::vector<int64_t>& hidden_states_shape) {
  const int bsz = seq_lens_this_time_shape[0];
  const int hidden_size = score_text_shape[1];
  const int hidden_size = hidden_states_shape[1];
  return {{bsz, hidden_size}};
}

@@ -87,8 +84,8 @@ std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(const paddle::Dat
    const paddle::DataType& mm_token_num_len_dtype,
    const paddle::DataType& seq_lens_this_time_dtype,
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& score_text_dtype) {
  return {score_text_dtype};
    const paddle::DataType& hidden_states_dtype) {
  return {hidden_states_dtype};
}

PD_BUILD_STATIC_OP(extract_text_token_output)
@@ -97,7 +94,7 @@ PD_BUILD_STATIC_OP(extract_text_token_output)
        "mm_token_num_len",
        "seq_lens_this_time",
        "cu_seqlens_q",
        "score_text"})
        "hidden_states"})
    .Outputs({"output"})
    .SetKernelFn(PD_KERNEL(ExtractTextTokenOutput))
    .SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape))
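Because the rename from score_text to hidden_states also renames the op input, call sites must pass the new name. A minimal sketch of a post-rename call follows, assuming the op is reached through FastDeploy's compiled GPU ops package; the import path and the leading max_seq_len / max_seq_len_index arguments are inferred from the kernel signature and may differ in detail.

# Hypothetical call site after the score_text -> hidden_states rename.
import paddle
from fastdeploy.model_executor.ops.gpu import extract_text_token_output  # path assumed

bsz, hidden_size, token_num = 4, 1024, 57
max_seq_len = paddle.to_tensor([32], dtype="int32")
max_seq_len_index = paddle.to_tensor([1], dtype="int32")
mm_token_num_len = paddle.to_tensor([16], dtype="int32")
seq_lens_this_time = paddle.ones([bsz], dtype="int32")
cu_seqlens_q = paddle.to_tensor([0, 8, 24, 40, 57], dtype="int32")
hidden_states = paddle.rand([token_num, hidden_size], dtype="float32")

out = extract_text_token_output(max_seq_len, max_seq_len_index, mm_token_num_len,
                                seq_lens_this_time, cu_seqlens_q, hidden_states)
# out: float32 [bsz, hidden_size], matching the InferShape/InferDtype rules above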
custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu (new file, 163 lines)
@@ -0,0 +1,163 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include "kernel_traits.h"
#include "flash_mask_attn_kernel.hpp"

template <typename paddle_type>
struct cuteType;

template <>
struct cuteType<phi::dtype::float16> {
  using type = cutlass::half_t;
};

template <>
struct cuteType<phi::dtype::bfloat16> {
  using type = cutlass::bfloat16_t;
};

template <typename T>
std::vector<paddle::Tensor> DispatchFlashAttentionMask(
    const paddle::Tensor& q_input,
    const paddle::Tensor& k_input,
    const paddle::Tensor& v_input,
    const paddle::Tensor& cu_seq_q,
    const paddle::Tensor& cu_seq_k,
    const paddle::Tensor& seq_len_encoder,
    const paddle::optional<paddle::Tensor>& mask,
    const int head_num,
    const int kv_head_num,
    const int head_dim,
    const int max_seq_len,
    const int max_enc_len_this_time,
    const int max_dec_len_this_time) {

  constexpr int kBlockM = 128;
  constexpr int kBlockN = 128;
  const int batch_size = cu_seq_q.dims()[0];

  paddle::Tensor out = paddle::empty(
      {q_input.dims()[0], head_num * head_dim}, q_input.dtype(), q_input.place());

  Flash_mask_params params;
  memset(&params, 0, sizeof(Flash_mask_params));

  params.q_ptr = const_cast<T*>(q_input.data<T>());
  params.k_ptr = const_cast<T*>(k_input.data<T>());
  params.v_ptr = const_cast<T*>(v_input.data<T>());
  params.o_ptr = const_cast<T*>(out.data<T>());
  params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
  params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
  params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
  params.head_num = head_num;
  params.kv_head_num = kv_head_num;
  params.max_seq_len_q = max_enc_len_this_time;
  params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
  params.batch_size = batch_size;
  params.gqa_group_size = head_num / kv_head_num;
  constexpr float kLog2e = 1.4426950408889634074;
  params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;

  using cute_type = typename cuteType<T>::type;

  if (mask) {
    params.mask = const_cast<int*>(mask.get().data<int>());
    flash_attn_headdim128<kBlockM, kBlockN, true, cute_type>(params, 0);
  } else {
    flash_attn_headdim128<kBlockM, kBlockN, false, cute_type>(params, 0);
  }

  return {out};
}


std::vector<paddle::Tensor> FlashAttentionMask(
    const paddle::Tensor& q_input,
    const paddle::Tensor& k_input,
    const paddle::Tensor& v_input,
    const paddle::Tensor& cu_seq_q,
    const paddle::Tensor& cu_seq_k,
    const paddle::Tensor& seq_len_encoder,
    const paddle::optional<paddle::Tensor> &mask,
    const int head_num,
    const int kv_head_num,
    const int head_dim,
    const int max_seq_len,
    const int max_enc_len_this_time,
    const int max_dec_len_this_time) {

  if (q_input.dtype() == paddle::DataType::FLOAT16) {
    using T = phi::dtype::float16;
    return std::move(
        DispatchFlashAttentionMask<T>(
            q_input,
            k_input,
            v_input,
            cu_seq_q,
            cu_seq_k,
            seq_len_encoder,
            mask,
            head_num,
            kv_head_num,
            head_dim,
            max_seq_len,
            max_enc_len_this_time,
            max_dec_len_this_time));
  } else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
    using T = phi::dtype::bfloat16;
    return std::move(
        DispatchFlashAttentionMask<T>(
            q_input,
            k_input,
            v_input,
            cu_seq_q,
            cu_seq_k,
            seq_len_encoder,
            mask,
            head_num,
            kv_head_num,
            head_dim,
            max_seq_len,
            max_enc_len_this_time,
            max_dec_len_this_time));
  }

}


PD_BUILD_OP(flash_attention_mask)
    .Inputs({
        "q_input",
        "k_input",
        "v_input",
        "cu_seq_q",
        "cu_seq_k",
        "seq_len_encoder",
        paddle::Optional("mask")})
    .Attrs({
        "head_num: int",
        "kv_head_num: int",
        "head_dim: int",
        "max_seq_len: int",
        "max_enc_len_this_time: int",
        "max_dec_len_this_time: int"})
    .Outputs({
        "out"})
    .SetKernelFn(PD_KERNEL(FlashAttentionMask));
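One detail of DispatchFlashAttentionMask above that is easy to miss: both the 1/sqrt(head_dim) attention scale and a log2(e) factor are folded into scale_softmax_log2, so the softmax inside the mainloop can be evaluated with base-2 exponentials (exp2f, or the ex2.approx PTX path in softmax.hpp) instead of expf. With raw logits x_i = q·k_i and running max M = max_j x_j, the identity being used is

\[
c = \frac{\log_2 e}{\sqrt{d}}, \qquad
\frac{\exp\big((x_i - M)/\sqrt{d}\big)}{\sum_j \exp\big((x_j - M)/\sqrt{d}\big)}
= \frac{2^{\,c\,(x_i - M)}}{\sum_j 2^{\,c\,(x_j - M)}} .
\]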
custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp (new file, 231 lines)
@@ -0,0 +1,231 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
|
||||
#include "kernel_traits.h"
|
||||
#include "mainloop_attn.hpp"
|
||||
#include "softmax.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <int kHeadDim>
|
||||
auto get_gmem_layout(int token_num, int head_num) {
|
||||
return make_layout(
|
||||
make_shape(token_num, kHeadDim, head_num),
|
||||
make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim));
|
||||
}
|
||||
|
||||
template <typename Ktraits>
|
||||
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
|
||||
compute_attn_ws(
|
||||
CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT Flash_mask_params const data_params) {
|
||||
|
||||
using Element = typename Ktraits::Element;
|
||||
using ElementAccum = typename Ktraits::ElementAccum;
|
||||
using SoftType = ElementAccum;
|
||||
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
||||
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
||||
|
||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
|
||||
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||
constexpr int kHeadDim = Ktraits::kHeadDim;
|
||||
constexpr bool NeedMask = Ktraits::NeedMask;
|
||||
|
||||
using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
|
||||
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||
|
||||
__align__(16) __shared__ int mask[kBlockM];
|
||||
|
||||
const int m_block = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
|
||||
if constexpr (NeedMask) {
|
||||
const int *mask_this_batch = data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;
|
||||
|
||||
for (int i = threadIdx.x; i < kBlockM; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
|
||||
mask[i] = mask_this_batch[i];
|
||||
}
|
||||
}
|
||||
|
||||
const int seq_len_q = data_params.seq_len_encoder[bidb];
|
||||
const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];
|
||||
|
||||
if (m_block * kBlockM >= seq_len_q) {
|
||||
return;
|
||||
}
|
||||
|
||||
int const lane_predicate = cute::elect_one_sync();
|
||||
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
|
||||
if (warp_idx == 0 && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
||||
}
|
||||
|
||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
pipeline_params.role = warp_group_idx == 0
|
||||
? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = NumMmaThreads;
|
||||
|
||||
if (warp_idx == 0 && lane_predicate) {
|
||||
shared_storage.barrier_Q.init(1);
|
||||
}
|
||||
|
||||
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
|
||||
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
|
||||
const int real_seq = seq_len_q - m_block * kBlockM;
|
||||
|
||||
const int n_block_max = NeedMask ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) : cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN);
|
||||
|
||||
if (warp_group_idx == 0) { // Producer
|
||||
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();
|
||||
|
||||
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
|
||||
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
|
||||
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
|
||||
collective_mainloop.load(
|
||||
mainloop_params,
|
||||
pipeline_k,
|
||||
pipeline_v,
|
||||
smem_pipe_write_k,
|
||||
smem_pipe_write_v,
|
||||
shared_storage,
|
||||
n_block_max,
|
||||
m_block,
|
||||
bidh,
|
||||
bidb,
|
||||
data_params.cu_seq_q,
|
||||
data_params.cu_seq_k,
|
||||
seq_len_q,
|
||||
seq_len_k);
|
||||
}
|
||||
} else { // Consumer
|
||||
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
|
||||
typename Ktraits::TiledMma1 tiled_mma1;
|
||||
|
||||
PipelineState smem_pipe_read_k, smem_pipe_read_v;
|
||||
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
|
||||
Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
|
||||
|
||||
collective_mainloop.mma(
|
||||
mainloop_params,
|
||||
pipeline_k,
|
||||
pipeline_v,
|
||||
smem_pipe_read_k,
|
||||
smem_pipe_read_v,
|
||||
tOrO,
|
||||
softmax,
|
||||
mask,
|
||||
n_block_max,
|
||||
threadIdx.x - NumCopyThreads,
|
||||
m_block,
|
||||
seq_len_q,
|
||||
seq_len_k,
|
||||
shared_storage);
|
||||
|
||||
const int o_head_stride = data_params.head_num * kHeadDim;
|
||||
const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;
|
||||
|
||||
collective_mainloop.store<NumMmaThreads>(
|
||||
mainloop_params,
|
||||
tOrO,
|
||||
shared_storage,
|
||||
tiled_mma1,
|
||||
threadIdx.x - NumCopyThreads,
|
||||
o_head_stride,
|
||||
real_seq,
|
||||
reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_flash_mask(Flash_mask_params ¶ms, cudaStream_t stream) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
|
||||
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
|
||||
|
||||
using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
typename CollectiveMainloop::Params mainloop_params =
|
||||
CollectiveMainloop::to_underlying_arguments({
|
||||
static_cast<Element const*>(params.q_ptr),
|
||||
get_gmem_layout<kHeadDim>(params.max_seq_len_q, params.head_num),
|
||||
static_cast<Element const*>(params.k_ptr),
|
||||
get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
|
||||
static_cast<Element const*>(params.v_ptr),
|
||||
get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
|
||||
params.scale_softmax_log2
|
||||
});
|
||||
|
||||
int num_blocks_m = cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
|
||||
|
||||
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
|
||||
|
||||
void *kernel;
|
||||
kernel = (void *)compute_attn_ws<Kernel_traits>;
|
||||
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = num_blocks_m;
|
||||
grid_dims.y = params.head_num;
|
||||
grid_dims.z = params.batch_size;
|
||||
|
||||
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
|
||||
dim3 block_dims(ctaSize);
|
||||
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
|
||||
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
|
||||
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
|
||||
}
|
||||
|
||||
template <int kBlockM, int kBlockN, bool NeedMask, typename InputType>
|
||||
void flash_attn_headdim128(Flash_mask_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
constexpr static int Headdim = 128;
|
||||
constexpr static int kNWarps = kBlockM / 16 + 4;
|
||||
constexpr static int kStages = 2;
|
||||
|
||||
using Ktraits = Flash_mask_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, NeedMask, InputType>;
|
||||
run_flash_mask<Ktraits>(params, stream);
|
||||
}
|
||||
custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h (new file, 124 lines)
@@ -0,0 +1,124 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct Flash_mask_params {
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
void * __restrict__ o_ptr;
|
||||
int * __restrict__ cu_seq_q;
|
||||
int * __restrict__ cu_seq_k;
|
||||
int * __restrict__ mask;
|
||||
int * seq_len_encoder;
|
||||
int head_num;
|
||||
int kv_head_num;
|
||||
int max_seq_len_q;
|
||||
int max_seq_len_k;
|
||||
int batch_size;
|
||||
int gqa_group_size;
|
||||
float scale_softmax_log2;
|
||||
};
|
||||
|
||||
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
|
||||
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
|
||||
struct SharedStorageQKVO {
|
||||
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
|
||||
union {
|
||||
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
|
||||
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
struct {
|
||||
cutlass::arch::ClusterTransactionBarrier barrier_Q;
|
||||
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
|
||||
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
|
||||
};
|
||||
};
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool NeedMask_, typename elem_type=cutlass::half_t>
|
||||
struct Flash_mask_kernel_traits {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using index_t = int32_t;
|
||||
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
|
||||
using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
|
||||
static constexpr int kStages = kStages_;
|
||||
static constexpr int NeedMask = NeedMask_;
|
||||
|
||||
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
|
||||
using TiledMma0 = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
|
||||
AtomLayoutMNK{}));
|
||||
using TiledMma1 = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutK =
|
||||
decltype(tile_to_shape(SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
|
||||
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutV =
|
||||
decltype(tile_to_shape(SmemLayoutAtomV{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
|
||||
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
|
||||
|
||||
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
|
||||
|
||||
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
|
||||
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
|
||||
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
|
||||
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
|
||||
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
|
||||
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
|
||||
using TiledCopyOThrLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
|
||||
LayoutRight{}));
|
||||
using TiledCopyOValLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(_1{}, Int<kNumVecElem>{}),
|
||||
LayoutRight{}));
|
||||
using GmemTiledCopyO = decltype(make_tiled_copy(
|
||||
TiledCopyOAtom{},
|
||||
TiledCopyOThrLayout{},
|
||||
TiledCopyOValLayout{}
|
||||
));
|
||||
|
||||
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
|
||||
using PipelineState = typename cutlass::PipelineState<kStages>;
|
||||
};
|
||||
custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp (new file, 431 lines)
@@ -0,0 +1,431 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "utils.hpp"
|
||||
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Ktraits>
|
||||
struct CollectiveMainloopAttn {
|
||||
|
||||
using Element = typename Ktraits::Element;
|
||||
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
||||
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
||||
|
||||
static constexpr int kStages = Ktraits::kStages;
|
||||
static constexpr int kHeadDim = Ktraits::kHeadDim;
|
||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||
static constexpr bool NeedMask = Ktraits::NeedMask;
|
||||
|
||||
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
|
||||
using StrideT = cute::Shape<int32_t, _1, int32_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
|
||||
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
|
||||
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
|
||||
using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;
|
||||
|
||||
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutK =
|
||||
decltype(tile_to_shape(SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
using SmemLayoutV = SmemLayoutK;
|
||||
// Note this is the transpose in terms of the view, not in terms of memory.
|
||||
using SmemLayoutVt =
|
||||
decltype(cute::composition(SmemLayoutV{},
|
||||
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
|
||||
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
|
||||
using SmemLayoutO = typename Ktraits::SmemLayoutO;
|
||||
using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;
|
||||
|
||||
using TMA_Q = decltype(make_tma_copy(
|
||||
GmemTiledCopyQ{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
||||
repeat_like(StrideT{}, int32_t(0)),
|
||||
StrideT{}
|
||||
),
|
||||
SmemLayoutQ{},
|
||||
select<0, 2>(TileShape_MNK{}),
|
||||
_1{})); // no mcast for Q
|
||||
|
||||
using TMA_KV = decltype(make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
||||
repeat_like(StrideT{}, int32_t(0)),
|
||||
StrideT{}
|
||||
),
|
||||
take<0, 2>(SmemLayoutK{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
|
||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
|
||||
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
|
||||
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
|
||||
|
||||
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
Element const* ptr_Q;
|
||||
LayoutT layout_Q;
|
||||
Element const* ptr_K;
|
||||
LayoutT layout_K;
|
||||
Element const* ptr_V;
|
||||
LayoutT layout_V;
|
||||
float const softmax_scale_log2;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
LayoutT layout_Q;
|
||||
LayoutT layout_K;
|
||||
LayoutT layout_V;
|
||||
cutlass::FastDivmod qhead_per_khead_divmod;
|
||||
TMA_Q tma_load_Q;
|
||||
TMA_KV tma_load_K, tma_load_V;
|
||||
float const softmax_scale_log2;
|
||||
};
|
||||
|
||||
|
||||
static Params
|
||||
to_underlying_arguments(Arguments const& args) {
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
|
||||
TMA_Q tma_load_Q = make_tma_copy(
|
||||
GmemTiledCopyQ{},
|
||||
mQ,
|
||||
SmemLayoutQ{},
|
||||
select<0, 2>(TileShape_MNK{}),
|
||||
_1{});
|
||||
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
|
||||
TMA_KV tma_load_K = make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
mK,
|
||||
SmemLayoutK{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
|
||||
TMA_KV tma_load_V = make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
mV,
|
||||
SmemLayoutV{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {args.layout_Q, args.layout_K, args.layout_V,
|
||||
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
|
||||
tma_load_Q, tma_load_K, tma_load_V,
|
||||
args.softmax_scale_log2};
|
||||
}
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_local_tile_tensor(
|
||||
const MTensor &m_tensor,
|
||||
const Shape &tile_shape,
|
||||
const int *cu_seq_len,
|
||||
const int bidh,
|
||||
const int bidb,
|
||||
const int actual_seq_len) const {
|
||||
auto g_offset = local_tile(
|
||||
m_tensor(_, _, bidh),
|
||||
cute::make_shape(1, get<1>(tile_shape)),
|
||||
make_coord(cu_seq_len[bidb], _0{}));
|
||||
auto g_sequence = make_tensor(
|
||||
g_offset.data(),
|
||||
make_layout(
|
||||
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
|
||||
g_offset.stride()
|
||||
));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void
|
||||
load(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_k,
|
||||
MainloopPipeline pipeline_v,
|
||||
PipelineState& smem_pipe_write_k,
|
||||
PipelineState& smem_pipe_write_v,
|
||||
SharedStorage &shared_storage,
|
||||
const int n_block_max,
|
||||
const int m_block,
|
||||
const int bidh,
|
||||
const int bidb,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int seq_len_q,
|
||||
const int seq_len_k) {
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
|
||||
|
||||
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
|
||||
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
|
||||
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
|
||||
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
|
||||
|
||||
Tensor gQ = get_local_tile_tensor(
|
||||
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
|
||||
Tensor gK = get_local_tile_tensor(
|
||||
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
|
||||
Tensor gV = get_local_tile_tensor(
|
||||
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
|
||||
|
||||
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
|
||||
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
|
||||
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
|
||||
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK));
|
||||
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV));
|
||||
|
||||
uint16_t mcast_mask_kv = 0;
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
|
||||
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
|
||||
}
|
||||
|
||||
|
||||
if (lane_predicate) {
|
||||
pipeline_k.producer_acquire(smem_pipe_write_k);
|
||||
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
|
||||
tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
|
||||
++smem_pipe_write_k;
|
||||
}
|
||||
|
||||
if (lane_predicate) {
|
||||
#pragma unroll 2
|
||||
for (; n_block > 0; --n_block) {
|
||||
pipeline_k.producer_acquire(smem_pipe_write_k);
|
||||
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
|
||||
tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
|
||||
++smem_pipe_write_k;
|
||||
pipeline_v.producer_acquire(smem_pipe_write_v);
|
||||
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
|
||||
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
|
||||
++smem_pipe_write_v;
|
||||
}
|
||||
}
|
||||
if (lane_predicate) {
|
||||
pipeline_v.producer_acquire(smem_pipe_write_v);
|
||||
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
|
||||
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
|
||||
++smem_pipe_write_v;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SharedStorage, typename FrgTensorO, typename Softmax>
|
||||
CUTLASS_DEVICE void
|
||||
mma(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_k,
|
||||
MainloopPipeline pipeline_v,
|
||||
PipelineState& smem_pipe_read_k,
|
||||
PipelineState& smem_pipe_read_v,
|
||||
FrgTensorO& tOrO,
|
||||
Softmax& softmax,
|
||||
const int *mask,
|
||||
const int n_block_max,
|
||||
const int thread_idx,
|
||||
const int m_block,
|
||||
const int seq_len_q,
|
||||
const int seq_len_k,
|
||||
SharedStorage& shared_storage) {
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
|
||||
|
||||
typename Ktraits::TiledMma0 tiled_mma0;
|
||||
typename Ktraits::TiledMma1 tiled_mma1;
|
||||
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
|
||||
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMma0.partition_fragment_B(sK);
|
||||
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
|
||||
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }
|
||||
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k);
|
||||
++smem_pipe_read_k;
|
||||
|
||||
int mask_start_idx;
|
||||
int mask_row_id;
|
||||
int col_base;
|
||||
|
||||
if constexpr (NeedMask) {
|
||||
const int lane_id = thread_idx % 32;
|
||||
mask_start_idx = mask[0] / kBlockN - 1;
|
||||
|
||||
mask_row_id = thread_idx / 32 * 16 + lane_id / 4;
|
||||
|
||||
col_base = thread_idx % 4 * 2;
|
||||
|
||||
app_mask(
|
||||
tSrS,
|
||||
mask,
|
||||
mask_row_id,
|
||||
col_base + n_block * kBlockN);
|
||||
} else {
|
||||
auto col_limit_causal = [&](int row, int n_block) {
|
||||
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
|
||||
};
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
|
||||
Tensor tScS = threadMma0.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
if (int(get<1>(tScS(i))) >=
|
||||
std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
|
||||
tSrS(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
|
||||
|
||||
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
|
||||
Tensor scores_scale = make_fragment_like(softmax.row_max);
|
||||
clear(scores_scale);
|
||||
|
||||
#pragma unroll 1
|
||||
for (; n_block > 0; --n_block) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
|
||||
if constexpr (NeedMask) {
|
||||
if (n_block >= mask_start_idx) {
|
||||
app_mask(
|
||||
tSrS,
|
||||
mask,
|
||||
mask_row_id,
|
||||
col_base + n_block * kBlockN);
|
||||
}
|
||||
}
|
||||
|
||||
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
warpgroup_wait<1>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k); // release K
|
||||
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
||||
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_v.consumer_release(smem_pipe_read_v); // release V
|
||||
++smem_pipe_read_k;
|
||||
++smem_pipe_read_v;
|
||||
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
|
||||
}
|
||||
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
|
||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_v.consumer_release(smem_pipe_read_v);
|
||||
++smem_pipe_read_v;
|
||||
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
return;
|
||||
}
|
||||
|
||||
template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
|
||||
CUTLASS_DEVICE void
|
||||
store(Params const& mainloop_params,
|
||||
FrgTensorO const& tOrO,
|
||||
SharedStorage& shared_storage,
|
||||
TiledMma tiled_mma,
|
||||
int thread_idx,
|
||||
const int o_head_stride,
|
||||
const int real_seq,
|
||||
T * out_ptr) {
|
||||
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOrO_out = convert_type<Element>(tOrO);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
|
||||
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
|
||||
|
||||
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(o_head_stride, _1{}));
|
||||
|
||||
GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
|
||||
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
|
||||
|
||||
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
|
||||
|
||||
if (real_seq >= kBlockM) {
|
||||
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
|
||||
} else {
|
||||
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
custom_ops/gpu_ops/flash_mask_attn/softmax.hpp (new file, 206 lines)
@@ -0,0 +1,206 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "utils.hpp"
|
||||
|
||||
|
||||
using namespace cute;
|
||||
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
|
||||
}
|
||||
|
||||
__forceinline__ __device__ __half2 half_exp(__half2 x) {
|
||||
uint32_t tmp_out, tmp_in;
|
||||
tmp_in = reinterpret_cast<uint32_t&>(x);
|
||||
asm ("ex2.approx.f16x2 %0, %1;\n"
|
||||
: "=r"(tmp_out)
|
||||
: "r"(tmp_in));
|
||||
__half2 out = reinterpret_cast<__half2&>(tmp_out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const float max_scaled = max(mi) * scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
CUTLASS_DEVICE Softmax() {};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = row_max(mi);
|
||||
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
row_sum(mi) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template<bool Is_first, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT scores_scale;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = 1.0f / sum;
|
||||
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
|
||||
scores_scale(mi) = inv_sum;
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template<typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
};
|
||||
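For readers skimming the Softmax struct above: max(), online_softmax(), rescale_o() and finalize() together implement the standard online-softmax recurrence of flash attention, expressed in the same base-2 domain (scale c = log2(e)/sqrt(d)) as the kernels. Per row, when a new block of raw scores x^{(t)} arrives,

\[
m_t = \max\big(m_{t-1},\, \max_j x^{(t)}_j\big), \qquad
\alpha_t = 2^{\,c\,(m_{t-1} - m_t)},
\]
\[
\ell_t = \alpha_t\,\ell_{t-1} + \sum_j 2^{\,c\,(x^{(t)}_j - m_t)}, \qquad
O_t = \alpha_t\,O_{t-1} + \tilde P^{(t)} V^{(t)},
\]

where the correction factor alpha_t is exactly the scores_scale returned by max() and applied by rescale_o(). finalize() then returns 1/ell for the last rescale of the output and stores the per-row log-sum-exp in row_sum.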
custom_ops/gpu_ops/flash_mask_attn/utils.hpp (new file, 453 lines)
@@ -0,0 +1,453 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename T>
|
||||
struct PackedHalf;
|
||||
|
||||
template<>
|
||||
struct PackedHalf<cutlass::half_t> {
|
||||
using Type = __half2;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct PackedHalf<cutlass::bfloat16_t> {
|
||||
using Type = nv_bfloat162;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
__forceinline__ __device__ auto float_2_half2(const float x) {
|
||||
if constexpr (std::is_same<T, cutlass::half_t>::value) {
|
||||
return __float2half2_rn(x);
|
||||
} else {
|
||||
return __float2bfloat162_rn(x);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
struct uint16 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
uint4 s;
|
||||
uint4 t;
|
||||
};
|
||||
|
||||
|
||||
struct uint8 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
};
|
||||
|
||||
template<int BYTES>
|
||||
struct BytesToType {};
|
||||
|
||||
template<>
|
||||
struct BytesToType<64> {
|
||||
using Type = uint16;
|
||||
static_assert(sizeof(Type) == 64);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<32> {
|
||||
using Type = uint8;
|
||||
static_assert(sizeof(Type) == 32);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
template<typename Elt_type, uint32_t NUM_ELT>
|
||||
struct Vec {
|
||||
|
||||
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
|
||||
|
||||
using Vec_type = typename BytesToType<BYTES>::Type;
|
||||
|
||||
using Alias_type = union {
|
||||
Vec_type vec;
|
||||
Elt_type elt[NUM_ELT];
|
||||
};
|
||||
|
||||
Alias_type data;
|
||||
|
||||
inline __device__ Vec() {}
|
||||
|
||||
template<typename S>
|
||||
inline __device__ void to(Vec<S, NUM_ELT> &other) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
other.data.elt[it] = S(this->data.elt[it]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ void assign(const Op &op) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
this->data.elt[it] = op(it);
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void load_from(const void *base_ptr) {
|
||||
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
|
||||
}
|
||||
|
||||
|
||||
inline __device__ void store_to(void *base_ptr) {
|
||||
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
|
||||
}
|
||||
|
||||
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
|
||||
static_assert(NUM_ELT % 2 == 0);
|
||||
using type = typename PackedHalf<Elt_type>::Type;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT / 2; it++) {
|
||||
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
|
||||
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
|
||||
static_assert(NUM_ELT % 2 == 0);
|
||||
using type = typename PackedHalf<Elt_type>::Type;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT / 2; it++) {
|
||||
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
|
||||
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
|
||||
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void set_zero() {
|
||||
constexpr int size = sizeof(Vec_type) / sizeof(int);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size; ++i) {
|
||||
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, int PackSize>
|
||||
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
|
||||
static_assert(PackSize % 2 == 0);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < PackSize / 2; i++) {
|
||||
const float cos_inv_freq = cos.data.elt[i];
|
||||
const float sin_inv_freq = sin.data.elt[i];
|
||||
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
|
||||
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
|
||||
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
|
||||
vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
|
||||
}
|
||||
}
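The per-pair rotation applied above is the usual rotary embedding. A NumPy reference of the same arithmetic, assuming cos/sin hold the per-frequency values for the position in question (names are illustrative, not taken from this diff):

import numpy as np

def rotary_reference(vec, cos, sin):
    """vec: (..., 2*k); cos, sin: (..., k) for the same position and frequencies."""
    x0, x1 = vec[..., 0::2], vec[..., 1::2]
    out = np.empty_like(vec)
    out[..., 0::2] = cos * x0 - sin * x1   # matches vec.data.elt[2*i]
    out[..., 1::2] = sin * x0 + cos * x1   # matches vec.data.elt[2*i + 1]
    return out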
|
||||
|
||||
template <typename Tensor>
|
||||
__forceinline__ __device__ void app_mask(
|
||||
Tensor &tSrS,
|
||||
const int *mask,
|
||||
const int &mask_row_id,
|
||||
const int &col_base) {
|
||||
const float mask_value = -1000000.0f;
|
||||
for (int i = 0; i < size(tSrS); i+=8) {
|
||||
const int col = i * 2 + col_base;
|
||||
if (col >= mask[mask_row_id]) {
|
||||
tSrS(i) = mask_value;
|
||||
}
|
||||
if (col + 1 >= mask[mask_row_id]) {
|
||||
tSrS(i + 1) = mask_value;
|
||||
}
|
||||
if (col >= mask[mask_row_id + 8]) {
|
||||
tSrS(i + 2) = mask_value;
|
||||
}
|
||||
if (col + 1 >= mask[mask_row_id + 8]) {
|
||||
tSrS(i + 3) = mask_value;
|
||||
}
|
||||
if (col + 8 >= mask[mask_row_id]) {
|
||||
tSrS(i + 4) = mask_value;
|
||||
}
|
||||
if (col + 9 >= mask[mask_row_id]) {
|
||||
tSrS(i + 5) = mask_value;
|
||||
}
|
||||
if (col + 8 >= mask[mask_row_id + 8]) {
|
||||
tSrS(i + 6) = mask_value;
|
||||
}
|
||||
if (col + 9 >= mask[mask_row_id + 8]) {
|
||||
tSrS(i + 7) = mask_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct HalfMax;
|
||||
template<>
|
||||
struct HalfMax<cutlass::half_t> {
|
||||
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
|
||||
__half2 res;
|
||||
asm volatile("max.f16x2 %0, %1, %2;\n" :
|
||||
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&x)),
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&y)));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct HalfMax<cutlass::bfloat16_t> {
|
||||
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
|
||||
nv_bfloat162 res;
|
||||
asm volatile("max.bf16x2 %0, %1, %2;\n" :
|
||||
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&x)),
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&y)));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct HalfMin;
|
||||
template<>
|
||||
struct HalfMin<cutlass::half_t> {
|
||||
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
|
||||
__half2 res;
|
||||
asm volatile("min.f16x2 %0, %1, %2;\n" :
|
||||
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&x)),
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&y)));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct HalfMin<cutlass::bfloat16_t> {
|
||||
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
|
||||
nv_bfloat162 res;
|
||||
asm volatile("min.bf16x2 %0, %1, %2;\n" :
|
||||
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&x)),
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&y)));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
|
||||
__forceinline__ __device__ void copy(
|
||||
TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &identity_MN,
|
||||
const int max_MN = 0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template<typename T, typename ReductionOp, int block_size>
|
||||
__inline__ __device__ T BlockAllReduce(T val) {
|
||||
typedef cub::BlockReduce<T, block_size> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ T result_broadcast;
|
||||
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
|
||||
if (threadIdx.x == 0) { result_broadcast = result; }
|
||||
__syncthreads();
|
||||
return result_broadcast;
|
||||
}
|
||||
|
||||
template<typename T, int block_size>
|
||||
__inline__ __device__ T BlockScanSum(T val) {
|
||||
typedef cub::BlockScan<int, block_size> BlockScanT;
|
||||
__shared__ typename BlockScanT::TempStorage temp_storage;
|
||||
|
||||
int aggregate;
|
||||
BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);
|
||||
__syncthreads();
|
||||
return val;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct MinOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MinOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
|
||||
};
|
||||
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<typename MMA_traits, typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
|
||||
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
|
||||
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
|
||||
} else { // SM80
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
if constexpr (mma_shape_K == 8) {
|
||||
return acc_layout;
|
||||
} else {
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2,
|
||||
typename TiledMma>
|
||||
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
|
||||
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
||||
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
|
||||
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (arrive) {
|
||||
warpgroup_arrive();
|
||||
}
|
||||
if constexpr (zero_init) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
} else {
|
||||
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
}
|
||||
if constexpr (commit) {
|
||||
warpgroup_commit_batch();
|
||||
}
|
||||
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||
}
|
||||
|
||||
|
||||
template<typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = acc_layout;
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
|
||||
} else { // SM80
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename ReductionOp, int thread_group_width = 32>
|
||||
__inline__ __device__ T WarpAllReduce(T val) {
|
||||
ReductionOp op;
|
||||
#pragma unroll
|
||||
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
|
||||
val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
|
||||
}
|
||||
return val;
|
||||
}
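WarpAllReduce is the classic XOR-shuffle (butterfly) reduction: after log2(width) exchange rounds every lane of the group holds the reduction of the whole group. A small Python sketch of the access pattern, assuming a power-of-two width and one value per lane (my own sketch, not the kernel):

import numpy as np

def warp_all_reduce(vals, op=np.maximum, width=32):
    vals = np.asarray(vals, dtype=float).copy()   # len(vals) == width
    mask = width // 2
    while mask > 0:
        partner = np.arange(width) ^ mask         # __shfl_xor_sync pairs lane i with lane i ^ mask
        vals = op(vals, vals[partner])
        mask //= 2
    return vals                                   # every lane now holds the same reduced value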
@@ -109,11 +109,11 @@ void GetOutputEp(const paddle::Tensor& x,
  return;
}

void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
void GetOutputEPStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
  GetOutputEp(x, rank_id, wait_flag, 1);
}

void GetOutputDynamic(const paddle::Tensor& x,
void GetOutputEPDynamic(const paddle::Tensor& x,
                      int64_t rank_id,
                      bool wait_flag,
                      int msg_queue_id) {
@@ -125,11 +125,11 @@ PD_BUILD_STATIC_OP(get_output_ep)
    .Attrs({"rank_id: int64_t", "wait_flag: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(GetOutputStatic));
    .SetKernelFn(PD_KERNEL(GetOutputEPStatic));

PD_BUILD_STATIC_OP(get_output_ep_dynamic)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t", "wait_flag: bool", "msg_queue_id: int"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(GetOutputDynamic));
    .SetKernelFn(PD_KERNEL(GetOutputEPDynamic));
@@ -46,7 +46,11 @@ __global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
  const int ti = threadIdx.x;
  int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
  for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
#ifdef PADDLE_WITH_HIP
    batch_id_per_token[bi * max_seq_len - cum_offset + i] = cum_offset;
#else
    batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
#endif
  }
  if (ti == 0) {
    cum_offsets_out[bi] = cum_offset;
@@ -101,7 +105,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
                             cum_offsets_out.data<int>(),
                             seq_length);
  return {x_remove_padding,
          cum_offsets_out,
          batch_id_per_token,
          cu_seqlens_q,
          cu_seqlens_k}; // , enc_token_num, dec_token_num};
@@ -114,7 +117,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
    const std::vector<int64_t> &seq_len_shape) {
  int64_t bsz = seq_len_shape[0];
  int64_t seq_len = input_ids_shape[1];
  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -123,7 +126,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
    const paddle::DataType &token_num_dtype,
    const paddle::DataType &seq_len_dtype) {
  return {input_ids_dtype,
          seq_len_dtype,
          seq_len_dtype,
          seq_len_dtype,
          seq_len_dtype};
@@ -132,7 +134,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset)
    .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
    .Outputs({"x_remove_padding",
              "cum_offsets_out",
              "batch_id_per_token",
              "cu_seqlens_q",
              "cu_seqlens_k"})
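For orientation, the outputs of get_padding_offset can be described with a short NumPy sketch (my own reference, not the op itself): given per-batch sequence lengths packed into a [bsz, max_seq_len] buffer, it reports the padding removed before each batch, the batch id of every packed token, and the cumulative sequence-length boundaries.

import numpy as np

def padding_offset_reference(seq_lens, max_seq_len):
    bsz = len(seq_lens)
    pad_lens = max_seq_len - np.asarray(seq_lens)
    cum_offsets_out = np.concatenate(([0], np.cumsum(pad_lens)))[:bsz]  # padding removed before batch bi
    batch_id_per_token = np.repeat(np.arange(bsz), seq_lens)            # the non-HIP branch stores bi per token
    cu_seqlens = np.concatenate(([0], np.cumsum(seq_lens)))             # cu_seqlens_q / cu_seqlens_k, length bsz + 1
    return cum_offsets_out, batch_id_per_token, cu_seqlens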
||||
@@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) {
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef FP8_E4M3_MAX
|
||||
#define FP8_E4M3_MAX 448.0
|
||||
#endif
|
||||
|
||||
#ifndef DISPATCH_FLOAT_FP6_DTYPE
|
||||
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
|
||||
switch (pd_dtype) { \
|
||||
case phi::DataType::FLOAT32: { \
|
||||
using c_type = float; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::BFLOAT16: { \
|
||||
using c_type = phi::dtype::bfloat16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::FLOAT16: { \
|
||||
using c_type = phi::dtype::float16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1)
|
||||
return num;
|
||||
@@ -193,6 +221,12 @@ public:
|
||||
typedef uint8_t data_t;
|
||||
};
|
||||
|
||||
template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
|
||||
public:
|
||||
typedef __nv_fp8_e4m3 DataType;
|
||||
typedef paddle::float8_e4m3fn data_t;
|
||||
};
|
||||
|
||||
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
|
||||
T val[Size];
|
||||
|
||||
@@ -509,6 +543,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
|
||||
}
|
||||
|
||||
#ifndef PADDLE_WITH_HIP
|
||||
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
||||
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
|
||||
int mode = 0) {
|
||||
uint32_t flag;
|
||||
@@ -541,7 +576,7 @@ __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
|
||||
"l"(flag_addr));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
int max_shared_mem_per_block_opt_in = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||
@@ -556,3 +591,28 @@ inline int GetSMVersion() {
|
||||
return sm_version;
|
||||
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warpReduceMax(float value) {
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float blockReduceMax(float value) {
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
value = warpReduceMax(value);
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = value;
|
||||
__syncthreads();
|
||||
|
||||
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
if (warpId == 0) value = warpReduceMax(value);
|
||||
|
||||
return value;
|
||||
}
custom_ops/gpu_ops/machete/generate.py (new file, 574 lines)
@@ -0,0 +1,574 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import reduce
|
||||
from typing import Optional, Union
|
||||
|
||||
import jinja2
|
||||
|
||||
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
p = os.path.abspath(os.path.join(cur_dir, "../../third_party/cutlass/python"))
|
||||
sys.path.insert(0, p)
|
||||
|
||||
from cutlass_library import (
|
||||
EpilogueScheduleTag,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType,
|
||||
)
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from machete_cutlass_library_extension import (
|
||||
DataType,
|
||||
MACHETEDataType,
|
||||
MACHETEDataTypeMACHETEScalarTypeTag,
|
||||
MACHETEDataTypeNames,
|
||||
MACHETEDataTypePaddleDataTypeTag,
|
||||
MACHETEDataTypeSize,
|
||||
MACHETEDataTypeTag,
|
||||
MACHETEKernelScheduleTag,
|
||||
MixedInputKernelScheduleType,
|
||||
)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
#
|
||||
# Generator templating
|
||||
#
|
||||
|
||||
DISPATCH_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set type_sig = gen_type_sig(impl_config.types) -%}
|
||||
{% for s in impl_config.schedules %}
|
||||
extern paddle::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
|
||||
{%- endfor %}
|
||||
|
||||
paddle::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
|
||||
[[maybe_unused]] auto M = args.A.shape()[0];
|
||||
[[maybe_unused]] auto N = args.B.shape()[1];
|
||||
[[maybe_unused]] auto K = args.A.shape()[1];
|
||||
|
||||
if (!args.maybe_schedule) {
|
||||
{%- for cond, s in impl_config.heuristic %}
|
||||
{%if cond is not none%}if ({{cond}})
|
||||
{%- else %}else
|
||||
{%- endif %}
|
||||
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
|
||||
}
|
||||
|
||||
{%- for s in impl_config.schedules %}
|
||||
if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
|
||||
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
|
||||
{%- endfor %}
|
||||
PADDLE_ENFORCE(false, "machete_gemm(..) is not implemented ");
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
|
||||
static inline std::optional<paddle::DataType> maybe_scalartype(
|
||||
std::optional<paddle::Tensor> const& t) {
|
||||
if (!t) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return t->dtype();
|
||||
};
|
||||
}
|
||||
|
||||
paddle::Tensor mm_dispatch(MMArgs args) {
|
||||
auto out_type = args.maybe_out_type.value_or(args.A.dtype());
|
||||
auto a_type = args.A.dtype();
|
||||
auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
|
||||
auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
|
||||
auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
|
||||
auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set type_sig = gen_type_sig(t) -%}
|
||||
if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
|
||||
&& a_type == {{PaddleTypeTag[t.a]}}
|
||||
&& out_type == {{PaddleTypeTag[t.out]}}
|
||||
&& {%if t.b_group_scale != void -%}
|
||||
maybe_g_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
|
||||
{%- else %}!maybe_g_scales_type{%endif%}
|
||||
&& {%if t.b_group_zeropoint != void -%}
|
||||
maybe_g_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
|
||||
{%- else %}!maybe_g_zeros_type{%endif%}
|
||||
&& {%if t.b_channel_scale != void -%}
|
||||
maybe_ch_scales_type == {{PaddleTypeTag[t.b_channel_scale]}}
|
||||
{%- else %}!maybe_ch_scales_type{%endif%}
|
||||
&& {%if t.a_token_scale != void -%}
|
||||
maybe_tok_scales_type == {{PaddleTypeTag[t.a_token_scale]}}
|
||||
{%- else %}!maybe_tok_scales_type{%endif%}
|
||||
) {
|
||||
return mm_dispatch_{{type_sig}}(args);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
false, "machete_mm(..) is not implemented "
|
||||
"; implemented types are: \\n",
|
||||
{%- for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
"\\t{{gen_type_option_name(t)}}\\n",
|
||||
{%- endfor %}
|
||||
"");
|
||||
}
|
||||
|
||||
std::vector<std::string> supported_schedules_dispatch(
|
||||
SupportedSchedulesArgs args) {
|
||||
auto out_type = args.maybe_out_type.value_or(args.a_type);
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set schs = impl_config.schedules -%}
|
||||
if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
|
||||
&& args.a_type == {{PaddleTypeTag[t.a]}}
|
||||
&& out_type == {{PaddleTypeTag[t.out]}}
|
||||
&& {%if t.b_group_scale != void -%}
|
||||
args.maybe_group_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
|
||||
{%- else %}!args.maybe_group_scales_type{%endif%}
|
||||
&& {%if t.b_group_zeropoint != void-%}
|
||||
args.maybe_group_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
|
||||
{%- else %}!args.maybe_group_zeros_type{%endif%}
|
||||
) {
|
||||
return {
|
||||
{%- for s in impl_config.schedules %}
|
||||
"{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
};
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
return {};
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
IMPL_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
{% for sch in unique_schedules(impl_configs) %}
|
||||
{% set sch_sig = gen_sch_sig(sch) -%}
|
||||
struct sch_{{sch_sig}} {
|
||||
using TileShapeNM = Shape<{{
|
||||
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
|
||||
using ClusterShape = Shape<{{
|
||||
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
|
||||
// TODO: Reimplement
|
||||
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
|
||||
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
|
||||
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
};
|
||||
{% endfor %}
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set schs = impl_config.schedules -%}
|
||||
{% set type_sig = gen_type_sig(t) -%}
|
||||
|
||||
template<typename Sch>
|
||||
using Kernel_{{type_sig}} = MacheteKernelTemplate<
|
||||
{{DataTypeTag[t.a]}}, // ElementA
|
||||
{{DataTypeTag[t.b]}}, // ElementB
|
||||
{{DataTypeTag[t.out]}}, // ElementD
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
|
||||
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
|
||||
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
|
||||
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
Sch>;
|
||||
|
||||
{% for sch in schs %}
|
||||
{% set sch_sig = gen_sch_sig(sch) -%}
|
||||
paddle::Tensor
|
||||
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
|
||||
return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
|
||||
}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
PREPACK_TEMPLATE = """
|
||||
#include "../machete_prepack_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
paddle::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
|
||||
{%- for t in types %}
|
||||
{% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
|
||||
if (args.a_type == {{PaddleTypeTag[t.a]}}
|
||||
&& args.b_type.size_bits() == {{t.b_num_bits}}
|
||||
&& convert_type == {{PaddleTypeTag[t.convert]}}) {
|
||||
return prepack_impl<
|
||||
PrepackedLayoutBTemplate<
|
||||
{{DataTypeTag[t.a]}}, // ElementA
|
||||
{{DataTypeTag[b_type]}}, // ElementB
|
||||
{{DataTypeTag[t.convert]}}, // ElementConvert
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
|
||||
>(args.B);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
PADDLE_ENFORCE(false,
|
||||
"prepack_B_dispatch(..) is not implemented");
|
||||
}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
|
||||
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleConfig:
|
||||
tile_shape_mn: tuple[int, int]
|
||||
cluster_shape_mnk: tuple[int, int, int]
|
||||
kernel_schedule: MixedInputKernelScheduleType
|
||||
epilogue_schedule: EpilogueScheduleType
|
||||
tile_scheduler: TileSchedulerType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeConfig:
|
||||
a: DataType
|
||||
b: Union[DataType, MACHETEDataType]
|
||||
b_group_scale: DataType
|
||||
b_group_zeropoint: DataType
|
||||
b_channel_scale: DataType
|
||||
a_token_scale: DataType
|
||||
out: DataType
|
||||
accumulator: DataType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PrepackTypeConfig:
|
||||
a: DataType
|
||||
b_num_bits: int
|
||||
convert: DataType
|
||||
accumulator: DataType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImplConfig:
|
||||
types: TypeConfig
|
||||
schedules: list[ScheduleConfig]
|
||||
heuristic: list[tuple[Optional[str], ScheduleConfig]]
|
||||
|
||||
|
||||
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
tile_shape = f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
||||
cluster_shape = (
|
||||
f"{schedule_config.cluster_shape_mnk[0]}"
|
||||
+ f"x{schedule_config.cluster_shape_mnk[1]}"
|
||||
+ f"x{schedule_config.cluster_shape_mnk[2]}"
|
||||
)
|
||||
kernel_schedule = MACHETEKernelScheduleTag[schedule_config.kernel_schedule].split("::")[-1]
|
||||
epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split("::")[-1]
|
||||
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
|
||||
|
||||
return f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + f"_{epilogue_schedule}_{tile_scheduler}"
|
||||
|
||||
|
||||
# mostly unique shorter sch_sig
|
||||
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
kernel_terse_names_replace = {
|
||||
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
|
||||
"TmaWarpSpecializedCooperative_": "TmaCoop_",
|
||||
"StreamKScheduler": "streamK",
|
||||
}
|
||||
|
||||
sch_sig = generate_sch_sig(schedule_config)
|
||||
for orig, terse in kernel_terse_names_replace.items():
|
||||
sch_sig = sch_sig.replace(orig, terse)
|
||||
return sch_sig
|
||||
|
||||
|
||||
# unique type_name
|
||||
def generate_type_signature(kernel_types: TypeConfig):
|
||||
return str("".join([MACHETEDataTypeNames[getattr(kernel_types, field.name)] for field in fields(TypeConfig)]))
|
||||
|
||||
|
||||
def generate_type_option_name(kernel_types: TypeConfig):
|
||||
return ", ".join(
|
||||
[
|
||||
f"{field.name.replace('b_', 'with_')+'_type'}=" + MACHETEDataTypeNames[getattr(kernel_types, field.name)]
|
||||
for field in fields(TypeConfig)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def is_power_of_two(n):
|
||||
return (n != 0) and (n & (n - 1) == 0)
|
||||
|
||||
|
||||
def to_cute_constant(value: list[int]):
|
||||
|
||||
def _to_cute_constant(value: int):
|
||||
if is_power_of_two(value):
|
||||
return f"_{value}"
|
||||
else:
|
||||
return f"Int<{value}>"
|
||||
|
||||
if isinstance(value, Iterable):
|
||||
return [_to_cute_constant(value) for value in value]
|
||||
else:
|
||||
return _to_cute_constant(value)
|
||||
|
||||
|
||||
def unique_schedules(impl_configs: list[ImplConfig]):
|
||||
return list(set(sch for impl_config in impl_configs for sch in impl_config.schedules))
|
||||
|
||||
|
||||
def unsigned_type_with_bitwidth(num_bits):
|
||||
return {
|
||||
4: DataType.u4,
|
||||
8: DataType.u8,
|
||||
16: DataType.u16,
|
||||
32: DataType.u32,
|
||||
64: DataType.u64,
|
||||
}[num_bits]
|
||||
|
||||
|
||||
template_globals = {
|
||||
"void": DataType.void,
|
||||
"DataTypeTag": MACHETEDataTypeTag,
|
||||
"MACHETEScalarTypeTag": MACHETEDataTypeMACHETEScalarTypeTag,
|
||||
"PaddleTypeTag": MACHETEDataTypePaddleDataTypeTag,
|
||||
"KernelScheduleTag": MACHETEKernelScheduleTag,
|
||||
"EpilogueScheduleTag": EpilogueScheduleTag,
|
||||
"TileSchedulerTag": TileSchedulerTag,
|
||||
"to_cute_constant": to_cute_constant,
|
||||
"gen_sch_sig": generate_terse_sch_sig,
|
||||
"gen_type_sig": generate_type_signature,
|
||||
"unique_schedules": unique_schedules,
|
||||
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
|
||||
"gen_type_option_name": generate_type_option_name,
|
||||
}
|
||||
|
||||
|
||||
def create_template(template_str):
|
||||
template = jinja2.Template(template_str)
|
||||
template.globals.update(template_globals)
|
||||
return template
|
||||
|
||||
|
||||
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
|
||||
mm_impl_template = create_template(IMPL_TEMPLATE)
|
||||
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||
|
||||
|
||||
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||
sources = []
|
||||
|
||||
sources.append(
|
||||
(
|
||||
"machete_mm_dispatch",
|
||||
mm_dispatch_template.render(impl_configs=impl_configs),
|
||||
)
|
||||
)
|
||||
|
||||
prepack_types = []
|
||||
for impl_config in impl_configs:
|
||||
convert_type = (
|
||||
impl_config.types.a
|
||||
if impl_config.types.b_group_scale == DataType.void
|
||||
else impl_config.types.b_group_scale
|
||||
)
|
||||
prepack_types.append(
|
||||
PrepackTypeConfig(
|
||||
a=impl_config.types.a,
|
||||
b_num_bits=MACHETEDataTypeSize[impl_config.types.b],
|
||||
convert=convert_type,
|
||||
accumulator=impl_config.types.accumulator,
|
||||
)
|
||||
)
|
||||
|
||||
def prepacked_type_key(prepack_type: PrepackTypeConfig):
|
||||
# For now we we can just use the first accumulator type seen since
|
||||
# the tensor core shapes/layouts don't vary based on accumulator
|
||||
# type so we can generate less code this way
|
||||
return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)
|
||||
|
||||
unique_prepack_types = []
|
||||
prepack_types_seen = set()
|
||||
for prepack_type in prepack_types:
|
||||
key = prepacked_type_key(prepack_type)
|
||||
if key not in prepack_types_seen:
|
||||
unique_prepack_types.append(prepack_type)
|
||||
prepack_types_seen.add(key)
|
||||
|
||||
sources.append(
|
||||
(
|
||||
"machete_prepack",
|
||||
prepack_dispatch_template.render(
|
||||
types=unique_prepack_types,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Split up impls across files
|
||||
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
||||
num_impls_per_file = math.ceil(num_impls / num_impl_files)
|
||||
|
||||
files_impls: list[list[ImplConfig]] = [[]]
|
||||
|
||||
curr_num_impls_assigned = 0
|
||||
curr_impl_in_file = 0
|
||||
curr_impl_configs = deepcopy(list(reversed(impl_configs)))
|
||||
|
||||
while curr_num_impls_assigned < num_impls:
|
||||
room_left_in_file = num_impls_per_file - curr_impl_in_file
|
||||
if room_left_in_file == 0:
|
||||
files_impls.append([])
|
||||
room_left_in_file = num_impls_per_file
|
||||
curr_impl_in_file = 0
|
||||
|
||||
curr_ic = curr_impl_configs[-1]
|
||||
if len(curr_ic.schedules) >= room_left_in_file:
|
||||
# Break apart the current impl config
|
||||
tmp_ic = deepcopy(curr_ic)
|
||||
tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
|
||||
curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
|
||||
files_impls[-1].append(tmp_ic)
|
||||
else:
|
||||
files_impls[-1].append(curr_ic)
|
||||
curr_impl_configs.pop()
|
||||
curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
|
||||
curr_impl_in_file += len(files_impls[-1][-1].schedules)
|
||||
|
||||
for part, file_impls in enumerate(files_impls):
|
||||
sources.append(
|
||||
(
|
||||
f"machete_mm_impl_part{part+1}",
|
||||
mm_impl_template.render(impl_configs=file_impls),
|
||||
)
|
||||
)
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
def generate():
|
||||
# See csrc/quantization/machete/Readme.md, the Codegeneration for more info
|
||||
# about how this works
|
||||
SCRIPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
sch_common_params = dict(
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)
|
||||
|
||||
# Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
|
||||
default_tile_heuristic_config = {
|
||||
# M = 257+
|
||||
"M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
|
||||
"M > 256": ((128, 256), (2, 1, 1)),
|
||||
# M = 129-256
|
||||
"M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
|
||||
"M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
|
||||
"M > 128": ((128, 256), (2, 1, 1)),
|
||||
# M = 65-128
|
||||
"M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
|
||||
"M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
|
||||
"M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
|
||||
"M > 64": ((128, 128), (2, 1, 1)),
|
||||
# M = 33-64
|
||||
"M > 40 && K <= 6144 && N <= 6144": ((128, 32), (2, 1, 1)),
|
||||
"M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
|
||||
"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
|
||||
"M > 32": ((128, 64), (2, 1, 1)),
|
||||
# M = 17-32
|
||||
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
|
||||
"M > 16": ((256, 32), (2, 1, 1)),
|
||||
# M = 1-16
|
||||
"N >= 26624": ((256, 16), (1, 1, 1)),
|
||||
None: ((128, 16), (1, 1, 1)),
|
||||
}
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
default_heuristic = [
|
||||
(cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
|
||||
for cond, tile_config in default_tile_heuristic_config.items()
|
||||
]
|
||||
|
||||
def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):
|
||||
# Do not use schedules = list(set(...)) because we need to make sure
|
||||
# the output list is deterministic; otherwise the generated kernel file
|
||||
# will be non-deterministic and causes ccache miss.
|
||||
schedules = []
|
||||
for _, schedule_config in heuristic:
|
||||
if schedule_config not in schedules:
|
||||
schedules.append(schedule_config)
|
||||
return schedules
|
||||
|
||||
impl_configs = []
|
||||
|
||||
GPTQ_kernel_type_configs = list(
|
||||
TypeConfig(
|
||||
a=a,
|
||||
b=b,
|
||||
b_group_scale=a,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.void,
|
||||
a_token_scale=DataType.void,
|
||||
out=a,
|
||||
accumulator=DataType.f32,
|
||||
)
|
||||
for b in (MACHETEDataType.u4b8, MACHETEDataType.u8b128)
|
||||
for a in (DataType.f16, DataType.bf16)
|
||||
)
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(
|
||||
GPTQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic),
|
||||
)
|
||||
]
|
||||
|
||||
output_dir = os.path.join(SCRIPT_DIR, "generated")
|
||||
|
||||
# Delete the "generated" directory if it exists
|
||||
if os.path.exists(output_dir):
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
# Create the "generated" directory
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Render each group of configurations into separate files
|
||||
for filename, code in create_sources(impl_configs):
|
||||
filepath = os.path.join(output_dir, f"{filename}.cu")
|
||||
with open(filepath, "w") as output_file:
|
||||
output_file.write(code)
|
||||
print(f"Rendered template to {filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate()
|
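The script above is a straightforward Jinja2 code generator: it expands the condition/schedule table into a cascaded dispatch function plus one impl per (type signature, schedule) pair, split across several .cu files. A toy version of that pattern (entirely my own, heavily simplified) looks like this:

import jinja2

toy_template = jinja2.Template(
    "paddle::Tensor dispatch(MMArgs args) {\n"
    "{% for cond, sch in heuristic %}"
    "  {% if cond %}if ({{ cond }}) {% else %}else {% endif %}return impl_{{ sch }}(args);\n"
    "{% endfor %}"
    "}\n"
)
# Like DISPATCH_TEMPLATE, every branch returns, so plain "if ... if ... else" is enough.
heuristic = [("M > 256", "128x256_2x1x1"), ("M > 64", "128x128_2x1x1"), (None, "128x16_1x1x1")]
print(toy_template.render(heuristic=heuristic))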
custom_ops/gpu_ops/machete/machete_collective_builder.cuh (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
#pragma once
|
||||
|
||||
#include "utils/machete_collective_builder.cuh"
|
||||
#include "machete_mainloop.cuh"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
struct MacheteKernelTag {};
|
||||
|
||||
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
|
||||
class ElementPairB_, class GmemLayoutB_, int AlignmentB,
|
||||
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
|
||||
class StageCountType, class KernelScheduleType>
|
||||
struct MacheteCollectiveBuilder<
|
||||
MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
|
||||
GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
|
||||
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<(
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative>)>> {
|
||||
using CollectiveOp = machete::MacheteCollectiveMma<
|
||||
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
|
||||
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
|
||||
StageCountType, KernelScheduleType>;
|
||||
};
|
||||
|
||||
}; // namespace cutlass::gemm::collective
|
||||
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from typing import Union
|
||||
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
DataTypeNames,
|
||||
DataTypeSize,
|
||||
DataTypeTag,
|
||||
KernelScheduleTag,
|
||||
KernelScheduleType,
|
||||
enum_auto,
|
||||
)
|
||||
|
||||
#
|
||||
# Extend cutlass library with custom types, and missing values
|
||||
#
|
||||
|
||||
|
||||
class MACHETEDataType(enum.Enum):
|
||||
u4b8 = enum_auto()
|
||||
u8b128 = enum_auto()
|
||||
|
||||
|
||||
class MixedInputKernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedPingpong = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
|
||||
|
||||
MACHETEDataTypeNames: dict[Union[MACHETEDataType, DataType], str] = {
|
||||
**DataTypeNames, # type: ignore
|
||||
**{
|
||||
MACHETEDataType.u4b8: "u4b8",
|
||||
MACHETEDataType.u8b128: "u8b128",
|
||||
},
|
||||
}
|
||||
|
||||
MACHETEDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
|
||||
**DataTypeTag, # type: ignore
|
||||
**{
|
||||
MACHETEDataType.u4b8: "cutlass::machete_uint4b8_t",
|
||||
MACHETEDataType.u8b128: "cutlass::machete_uint8b128_t",
|
||||
},
|
||||
}
|
||||
|
||||
MACHETEDataTypeSize: dict[Union[MACHETEDataType, DataType], int] = {
|
||||
**DataTypeSize, # type: ignore
|
||||
**{
|
||||
MACHETEDataType.u4b8: 4,
|
||||
MACHETEDataType.u8b128: 8,
|
||||
},
|
||||
}
|
||||
|
||||
MACHETEDataTypeMACHETEScalarTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
|
||||
MACHETEDataType.u4b8: "machete::kU4B8",
|
||||
MACHETEDataType.u8b128: "machete::kU8B128",
|
||||
DataType.u4: "machete::kU4",
|
||||
DataType.u8: "machete::kU8",
|
||||
DataType.s4: "machete::kS4",
|
||||
DataType.s8: "machete::kS8",
|
||||
DataType.f16: "machete::kFloat16",
|
||||
DataType.bf16: "machete::kBfloat16",
|
||||
}
|
||||
|
||||
MACHETEDataTypePaddleDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
|
||||
DataType.u8: "paddle::DataType::UINT8",
|
||||
DataType.s8: "paddle::DataType::INT8",
|
||||
DataType.e4m3: "paddle::DataType::FLOAT8_E4M3FN",
|
||||
DataType.s32: "paddle::DataType::INT32",
|
||||
DataType.f16: "paddle::DataType::FLOAT16",
|
||||
DataType.bf16: "paddle::DataType::BFLOAT16",
|
||||
DataType.f32: "paddle::DataType::FLOAT32",
|
||||
}
|
||||
|
||||
MACHETEKernelScheduleTag: dict[Union[MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
**{
|
||||
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
||||
},
|
||||
}
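As I read the naming convention, u4b8 and u8b128 are unsigned codes with an implicit bias (8 and 128 respectively), so a stored code c represents the signed weight c - bias. A quick Python illustration of that decoding (this is my assumption about the convention, not something stated in this diff):

import numpy as np

def decode_biased(codes, bias):
    # stored unsigned code -> signed value
    return np.asarray(codes, dtype=np.int32) - bias

print(decode_biased([0, 8, 15], bias=8))       # u4b8:   [-8, 0, 7]
print(decode_biased([0, 128, 255], bias=128))  # u8b128: [-128, 0, 127]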
custom_ops/gpu_ops/machete/machete_interleaving_utils.cuh (new file, 35 lines)
@@ -0,0 +1,35 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cute/layout.hpp"

namespace machete {

using namespace cute;

// get an interleaved block layout where each consecutive element has a
// stride of bit_stride and the block width is blk_bit_width,
// examples:
//   size_bits<T> = 8, bit_stride = 8,  blk_bit_width = 32 -> 4:1
//   size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
//   size_bits<T> = 4, bit_stride = 8,  blk_bit_width = 32 -> (4, 2):(2, 1)
//   size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
  static_assert(blk_bit_width % bit_stride == 0);
  static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);

  constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;

  if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
    // identity layout
    return Layout<Shape<Int<elems_per_blk>>>{};
  } else {
    constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
    constexpr auto num_strides = elems_per_blk / elems_per_stride;
    return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
                  Stride<Int<elems_per_stride>, Int<1>>>{};
  }
}

}; // namespace machete
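The shape/stride arithmetic in get_interleaved_blk_layout can be checked quickly in plain Python; the sketch below reproduces the four examples from the comment above.

def interleaved_blk_layout(size_bits, bit_stride, blk_bit_width):
    assert blk_bit_width % bit_stride == 0 and bit_stride % size_bits == 0
    elems_per_blk = blk_bit_width // size_bits
    if size_bits == bit_stride:
        return (elems_per_blk,), (1,)                    # identity layout, e.g. 4:1
    elems_per_stride = bit_stride // size_bits
    num_strides = elems_per_blk // elems_per_stride
    return (num_strides, elems_per_stride), (elems_per_stride, 1)

for args in [(8, 8, 32), (8, 16, 32), (4, 8, 32), (4, 16, 32)]:
    print(args, "->", interleaved_blk_layout(*args))     # 4:1, (2,2):(2,1), (4,2):(2,1), (2,4):(4,1)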
custom_ops/gpu_ops/machete/machete_mainloop.cuh (new file, 1473 lines; diff suppressed because it is too large)
custom_ops/gpu_ops/machete/machete_mm.cu (new file, 88 lines)
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "machete_mm_launcher.cuh"
|
||||
#include "machete_prepack_launcher.cuh"
|
||||
|
||||
template <typename T>
|
||||
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
|
||||
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
|
||||
}
|
||||
|
||||
paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
|
||||
int64_t b_type_id,
|
||||
std::optional<paddle::DataType> const& maybe_out_type,
|
||||
std::optional<paddle::Tensor> const& maybe_group_scales,
|
||||
std::optional<paddle::Tensor> const& maybe_group_zeros,
|
||||
int64_t maybe_group_size,
|
||||
std::optional<paddle::Tensor> const& maybe_channel_scales,
|
||||
std::optional<paddle::Tensor> const& maybe_token_scales,
|
||||
std::string maybe_schedule) {
|
||||
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
|
||||
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
|
||||
std::optional<std::string> maybe_schedule_opt;
|
||||
if (maybe_schedule == "") {
|
||||
maybe_schedule_opt = std::nullopt;
|
||||
} else {
|
||||
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
|
||||
}
|
||||
return machete::mm_dispatch({.A = A,
|
||||
.B = B,
|
||||
.b_type = b_type,
|
||||
.maybe_out_type = maybe_out_type,
|
||||
.maybe_group_scales = maybe_group_scales,
|
||||
.maybe_group_zeros = maybe_group_zeros,
|
||||
.maybe_group_size = maybe_group_size_opt,
|
||||
.maybe_channel_scales = maybe_channel_scales,
|
||||
.maybe_token_scales = maybe_token_scales,
|
||||
.maybe_schedule = maybe_schedule_opt});
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MacheteMMKernel(
|
||||
paddle::Tensor const& A, paddle::Tensor const& B,
|
||||
paddle::optional<paddle::Tensor> const& maybe_group_scales,
|
||||
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
|
||||
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
|
||||
paddle::optional<paddle::Tensor> const& maybe_token_scales,
|
||||
std::string const& b_type_str,
|
||||
std::string const& maybe_out_type_str,
|
||||
int64_t const& maybe_group_size,
|
||||
std::string const& maybe_schedule
|
||||
) {
|
||||
|
||||
machete::ScalarTypeId b_type_id;
|
||||
paddle::DataType maybe_out_type;
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
if (maybe_out_type_str == "float16") {
|
||||
maybe_out_type = paddle::DataType::FLOAT16;
|
||||
} else if (maybe_out_type_str == "bfloat16") {
|
||||
maybe_out_type = paddle::DataType::BFLOAT16;
|
||||
} else {
|
||||
maybe_out_type = A.dtype();
|
||||
}
|
||||
auto out = mm(A, B, b_type_id, maybe_out_type,
|
||||
ConvertToStdOptional<paddle::Tensor>(maybe_group_scales),
|
||||
ConvertToStdOptional<paddle::Tensor>(maybe_group_zeros),
|
||||
maybe_group_size,
|
||||
ConvertToStdOptional<paddle::Tensor>(maybe_channel_scales),
|
||||
ConvertToStdOptional<paddle::Tensor>(maybe_token_scales),
|
||||
maybe_schedule);
|
||||
return {out};
|
||||
}
custom_ops/gpu_ops/machete/machete_mm_kernel.cuh (new file, 305 lines)
@@ -0,0 +1,305 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
// The cutlass include order matters (annoyingly)
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
// clang-format on
|
||||
|
||||
#include "utils/cute_utils.cuh"
|
||||
#include "utils/machete_numeric_conversion.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "utils/paddle_utils.hpp"
|
||||
#include "machete_collective_builder.cuh"
|
||||
#include "machete_prepacked_layout.cuh"
|
||||
#include "machete_interleaving_utils.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
|
||||
// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
|
||||
// instructions only support sourcing from registers for the left-hand
|
||||
// operand, we want to upconvert/decompress the quantized operand in
|
||||
// register. Since the primary use case we want to support is Y = XW^t where
|
||||
// W is quantized, in this situation or right-hand operand is quantized so
|
||||
// we compute the transpose to move it to the left-hand side.
|
||||
template <typename ElementA_, typename ElementB_, typename ElementD_,
|
||||
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
|
||||
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
|
||||
typename ScheduleConfig>
|
||||
struct MacheteKernelTemplate {
|
||||
static constexpr bool with_C = false; // not ever used
|
||||
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
|
||||
static constexpr bool with_group_zeropoints =
|
||||
!std::is_same_v<GroupZeroT, void>;
|
||||
static constexpr bool with_channel_scales =
|
||||
!std::is_same_v<ChannelScaleT, void>;
|
||||
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
|
||||
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementC = cute::conditional_t<with_C, ElementD, void>;
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementCompute = AccumulatorT; // For Epilogue
|
||||
// Use dummy values when we don't have scales or zeropoints
|
||||
using ElementZGroup =
|
||||
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
|
||||
using ElementSGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementConvertGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementSChannel =
|
||||
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
|
||||
using ElementSToken =
|
||||
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
|
||||
|
||||
using BTypeTuple = cute::conditional_t<
|
||||
with_group_scales,
|
||||
cute::conditional_t<with_group_zeropoints,
|
||||
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
|
||||
cute::tuple<ElementB, ElementSGroup>>,
|
||||
ElementB>;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = LayoutC;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
// not actually used since B has the prepacked layout, but required by cutlass
|
||||
using _LayoutB = cutlass::layout::ColumnMajor;
|
||||
|
||||
// Interface strides expected by create_arguments (will get transposed)
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
|
||||
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
|
||||
using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
|
||||
using StrideZGroup = StrideSGroup;
|
||||
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutC_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
|
||||
using LayoutD_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using PrepackedLayoutB =
|
||||
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
|
||||
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
|
||||
|
||||
static int constexpr TileShapeK =
|
||||
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
|
||||
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
|
||||
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
|
||||
static int constexpr AlignmentC =
|
||||
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
|
||||
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
|
||||
|
||||
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
|
||||
cute::Int<TileShapeK>{}));
|
||||
using ClusterShape = typename ScheduleConfig::ClusterShape;
|
||||
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
|
||||
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
|
||||
using TileScheduler = typename ScheduleConfig::TileScheduler;
|
||||
|
||||
static_assert(
|
||||
(!with_channel_scales && !with_token_scales) ||
|
||||
((with_channel_scales && with_token_scales) &&
|
||||
std::is_same_v<ElementSChannel, ElementSToken>),
|
||||
"Currently token and channel scales (if present) must be the same type");
|
||||
|
||||
// Currently only supports float scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename fastdeploy::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
TileShape>;
|
||||
static_assert((with_channel_scales || with_token_scales) ||
|
||||
(std::is_same_v<ElementSChannel, float> &&
|
||||
std::is_same_v<ElementSToken, float>),
|
||||
"Currently token and channel scales (if present) must be float "
|
||||
"(and if one is present the other must be too)");
|
||||
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using EVTCompute =
|
||||
std::conditional_t<with_channel_scales || with_token_scales,
|
||||
typename ChTokScalesEpilogue::EVTCompute,
|
||||
StoreEpilogueCompute>;
|
||||
|
||||
// EVTCompute
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
|
||||
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::MacheteCollectiveBuilder<
|
||||
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
|
||||
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
|
||||
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// stride_B is unused (since B is prepacked), but still required by cutlass
|
||||
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
|
||||
|
||||
using Arguments = typename Gemm::Arguments;
|
||||
using MainloopArguments = typename GemmKernel::MainloopArguments;
|
||||
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
|
||||
|
||||
static Arguments create_arguments(
|
||||
cudaStream_t stream,
|
||||
paddle::Tensor const& A, // MxK matrix
|
||||
paddle::Tensor const& B, // KxN prepacked matrix
|
||||
paddle::Tensor& D, // MxN matrix
|
||||
std::optional<paddle::Tensor> const& maybe_g_scales, // scale_KxN matrix
|
||||
std::optional<paddle::Tensor> const& maybe_g_zeros, // scale_KxN matrix
|
||||
std::optional<int64_t> maybe_group_size,
|
||||
std::optional<paddle::Tensor> const& maybe_ch_scales, // len N vector
|
||||
std::optional<paddle::Tensor> const& maybe_tok_scales) // len M vector
|
||||
{
|
||||
static_assert(!with_group_zeropoints || with_group_scales);
|
||||
|
||||
int M = A.shape()[0], N = B.shape()[1], K = A.shape()[1];
|
||||
PD_CHECK(D.shape()[0] == M && D.shape()[1] == N);
|
||||
|
||||
auto layout_A = make_cute_layout<StrideA>(A, "A");
|
||||
auto layout_D = make_cute_layout<StrideD>(D, "D");
|
||||
auto layout_S_group =
|
||||
maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
|
||||
auto layout_Z_group =
|
||||
maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
|
||||
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
|
||||
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
|
||||
|
||||
auto unwrap = [](auto const& t) {
|
||||
return t ? t->data() : nullptr;
|
||||
};
|
||||
auto A_ptr = static_cast<ElementA const*>(A.data());
|
||||
auto B_ptr = static_cast<ElementB const*>(B.data());
|
||||
auto D_ptr = static_cast<ElementD*>(D.data());
|
||||
auto S_group_ptr =
|
||||
static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
|
||||
auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
|
||||
auto S_channel_ptr =
|
||||
static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
|
||||
auto S_token_ptr =
|
||||
static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
|
||||
|
||||
int const group_size =
|
||||
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
|
||||
int const scale_k = (K + group_size - 1) / group_size;
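// Illustrative note (not in the original source): a maybe_group_size of -1 or
// nullopt selects a single quantization group spanning all of K (so
// scale_k == 1, i.e. one scale row). For example, K = 4096 with
// group_size = 128 gives scale_k = (4096 + 127) / 128 = 32 scale rows.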
|
||||
|
||||
PD_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
|
||||
PD_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
|
||||
|
||||
if constexpr (with_group_scales) {
|
||||
PD_CHECK(S_group_ptr && layout_S_group);
|
||||
PD_CHECK((size<0>(*layout_S_group) == scale_k &&
|
||||
size<1>(*layout_S_group) == N));
|
||||
} else {
|
||||
PD_CHECK(!S_group_ptr, "Scales not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_group_zeropoints) {
|
||||
PD_CHECK(Z_group_ptr && layout_Z_group);
|
||||
PD_CHECK((size<0>(*layout_Z_group) == scale_k &&
|
||||
size<1>(*layout_Z_group) == N));
|
||||
PD_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
|
||||
"Scales and zeros must have the same layout");
|
||||
} else {
|
||||
PD_CHECK(!Z_group_ptr, "Zeropoints not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_channel_scales || with_token_scales) {
|
||||
PD_CHECK(
|
||||
(maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
|
||||
(maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
|
||||
}
|
||||
|
||||
// Transpose A and D
|
||||
// A doesn't need to be transposed since cutlass expects an NxK matrix
|
||||
// for B (which is At)
|
||||
auto stride_At = layout_A.stride();
|
||||
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
|
||||
|
||||
MainloopArguments mainloop_arguments{};
|
||||
// {Accum, C, C_layout, D, D_layout}
|
||||
EpilogueArguments epilogue_arguments{};
|
||||
|
||||
if constexpr (with_channel_scales || with_token_scales) {
|
||||
epilogue_arguments =
|
||||
EpilogueArguments{ChTokScalesEpilogue::prepare_args(
|
||||
*maybe_ch_scales, *maybe_tok_scales),
|
||||
nullptr,
|
||||
{},
|
||||
D_ptr,
|
||||
stride_Dt};
|
||||
} else {
|
||||
epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
|
||||
}
|
||||
|
||||
if constexpr (with_group_scales && with_group_zeropoints) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments = MainloopArguments{
|
||||
B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
|
||||
} else if constexpr (with_group_scales) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size};
|
||||
} else {
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
|
||||
}
|
||||
|
||||
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{N, M, K, 1},
|
||||
mainloop_arguments,
|
||||
epilogue_arguments};
|
||||
};
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) {
|
||||
return Gemm::get_workspace_size(args);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
|
||||
Gemm gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
|
||||
PD_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Machete kernel failed to initialize workspace");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
PD_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
78 custom_ops/gpu_ops/machete/machete_mm_launcher.cuh Normal file
@@ -0,0 +1,78 @@
|
||||
#pragma once
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include "machete_mm_kernel.cuh"
|
||||
#include "utils/paddle_utils.hpp"
|
||||
#include "utils/scalar_type.h"
|
||||
|
||||
namespace machete {
|
||||
|
||||
struct MMArgs {
|
||||
paddle::Tensor const& A;
|
||||
paddle::Tensor const& B;
|
||||
machete::ScalarType const& b_type;
|
||||
std::optional<paddle::DataType> const& maybe_out_type;
|
||||
std::optional<paddle::Tensor> const& maybe_group_scales;
|
||||
std::optional<paddle::Tensor> const& maybe_group_zeros;
|
||||
std::optional<int64_t> maybe_group_size;
|
||||
std::optional<paddle::Tensor> const& maybe_channel_scales;
|
||||
std::optional<paddle::Tensor> const& maybe_token_scales;
|
||||
std::optional<std::string> maybe_schedule;
|
||||
};
|
||||
|
||||
struct SupportedSchedulesArgs {
|
||||
paddle::DataType a_type;
|
||||
machete::ScalarType b_type;
|
||||
std::optional<paddle::DataType> maybe_group_scales_type;
|
||||
std::optional<paddle::DataType> maybe_group_zeros_type;
|
||||
std::optional<paddle::DataType> maybe_channel_scales_type;
|
||||
std::optional<paddle::DataType> maybe_token_scales_type;
|
||||
std::optional<paddle::DataType> maybe_out_type;
|
||||
};
|
||||
|
||||
paddle::Tensor mm_dispatch(MMArgs args);
|
||||
|
||||
std::vector<std::string> supported_schedules_dispatch(
|
||||
SupportedSchedulesArgs args);
|
||||
|
||||
template <typename MacheteKernel>
|
||||
paddle::Tensor run_impl(MMArgs args) {
|
||||
// const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
|
||||
|
||||
// auto device = args.A.device();
|
||||
// auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
auto place = args.A.place();
|
||||
cudaStream_t stream = args.A.stream();
|
||||
|
||||
int M = args.A.shape()[0];
|
||||
int N = args.B.shape()[1];
|
||||
int K = args.A.shape()[1];
|
||||
|
||||
// Allocate output
|
||||
paddle::Tensor D = paddle::empty(
|
||||
{M, N},
|
||||
equivalent_scalar_type_v<typename MacheteKernel::ElementD>,
|
||||
place);
|
||||
|
||||
auto arguments = MacheteKernel::create_arguments(
|
||||
stream, //
|
||||
args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
|
||||
args.maybe_group_size, args.maybe_channel_scales,
|
||||
args.maybe_token_scales);
|
||||
PD_CHECK(MacheteKernel::can_implement(arguments),
|
||||
"Machete kernel cannot be run with these arguments");
|
||||
|
||||
size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
|
||||
int S = static_cast<int>(workspace_size);
|
||||
// phi::Allocator* allocator = paddle::GetAllocator(place);
|
||||
// auto workspace = allocator->Allocate(workspace_size);
|
||||
// MacheteKernel::run(arguments, workspace->ptr(), stream);
|
||||
// paddle::Tensor workspace = paddle::empty({S}, paddle::DataType::UINT8, place);
|
||||
paddle::Tensor workspace = GetEmptyTensor({S}, paddle::DataType::UINT8, place);
|
||||
MacheteKernel::run(arguments, workspace.data(), stream);
|
||||
|
||||
return D;
|
||||
};
|
||||
|
||||
}; // namespace machete
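// Illustrative sketch (assumption, not part of the original file): how a call
// might be driven through mm_dispatch/run_impl. The wrapper function name is
// hypothetical; the kU4B8 constant is the one used by the prepack path below.
#if 0
paddle::Tensor example_machete_mm(
    paddle::Tensor const& A, paddle::Tensor const& B,
    std::optional<paddle::Tensor> const& group_scales) {
  // Named optionals so the const-reference members of MMArgs bind to objects
  // that outlive the call.
  std::optional<paddle::DataType> out_type = std::nullopt;
  std::optional<paddle::Tensor> group_zeros = std::nullopt;
  std::optional<paddle::Tensor> channel_scales = std::nullopt;
  std::optional<paddle::Tensor> token_scales = std::nullopt;

  machete::MMArgs args{A,
                       B,
                       machete::kU4B8,  // 4-bit, bias-8 (GPTQ-style) weights
                       out_type,
                       group_scales,
                       group_zeros,
                       /*maybe_group_size=*/128,
                       channel_scales,
                       token_scales,
                       /*maybe_schedule=*/std::nullopt};
  // mm_dispatch selects (or is told) a schedule and forwards to
  // run_impl<MacheteKernelTemplate<...>> for that instantiation.
  return machete::mm_dispatch(args);
}
#endif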
|
||||
73 custom_ops/gpu_ops/machete/machete_prepack_B.cu Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "machete_mm_launcher.cuh"
|
||||
#include "machete_prepack_launcher.cuh"
|
||||
|
||||
paddle::Tensor prepack_B(
|
||||
paddle::Tensor const& B, paddle::DataType const& a_type, int64_t b_type_id,
|
||||
std::string const& maybe_group_scales_type_str) {
|
||||
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
|
||||
std::optional<paddle::DataType> maybe_group_scales_type;
|
||||
if (maybe_group_scales_type_str == "float16") {
|
||||
maybe_group_scales_type = paddle::DataType::FLOAT16;
|
||||
}
|
||||
else if (maybe_group_scales_type_str == "bfloat16") {
|
||||
maybe_group_scales_type = paddle::DataType::BFLOAT16;
|
||||
}
|
||||
else if (maybe_group_scales_type_str == "float32") {
|
||||
maybe_group_scales_type = paddle::DataType::FLOAT32;
|
||||
}
|
||||
else if (maybe_group_scales_type_str == "") {
|
||||
maybe_group_scales_type = std::nullopt;
|
||||
}
|
||||
else {
|
||||
PADDLE_ENFORCE(false, "maybe_group_scales_type_str not supported!");
|
||||
}
|
||||
return machete::prepack_B_dispatch(
|
||||
{.B = B,
|
||||
.a_type = a_type,
|
||||
.b_type = b_type,
|
||||
.maybe_group_scales_type = maybe_group_scales_type});
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MachetePrepackBKernel(
|
||||
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
|
||||
std::string const& maybe_group_scales_type_str) {
|
||||
|
||||
machete::ScalarTypeId b_type_id;
|
||||
paddle::DataType a_type, maybe_group_scales_type;
|
||||
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
||||
if (a_type_str == "float16") {
|
||||
a_type = paddle::DataType::FLOAT16;
|
||||
}
|
||||
else if (a_type_str == "bfloat16") {
|
||||
a_type = paddle::DataType::BFLOAT16;
|
||||
}
|
||||
else {
|
||||
PADDLE_ENFORCE(false, "a_type_str not supported!");
|
||||
}
|
||||
auto Bt = paddle::experimental::transpose(B, {1, 0});
|
||||
paddle::Tensor B_prepacked = prepack_B(Bt, a_type, b_type_id, maybe_group_scales_type_str);
|
||||
return {B_prepacked};
|
||||
|
||||
}
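// Illustrative usage sketch (assumption, not part of the original file):
// prepacking 4-bit, bias-8 weights for bfloat16 activations. The shape and
// storage dtype of qweight are hypothetical (e.g. a packed (K/8, N) tensor);
// MachetePrepackBKernel transposes it internally before dispatching.
#if 0
std::vector<paddle::Tensor> example_prepack(paddle::Tensor const& qweight) {
  return MachetePrepackBKernel(qweight, /*a_type_str=*/"bfloat16",
                               /*b_type_str=*/"uint4b8",
                               /*maybe_group_scales_type_str=*/"bfloat16");
}
#endif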
|
||||
76 custom_ops/gpu_ops/machete/machete_prepack_kernel.cuh Normal file
@@ -0,0 +1,76 @@
|
||||
#pragma once
|
||||
|
||||
#include "machete_mm_kernel.cuh"
|
||||
#include "utils/cute_utils.cuh"
|
||||
#include "utils/paddle_utils.hpp"
|
||||
|
||||
namespace machete {
|
||||
|
||||
template <int threads, typename PrepackedLayoutB, typename BInTensor,
|
||||
typename ElementB>
|
||||
static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
|
||||
auto constexpr block_size =
|
||||
Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
|
||||
auto constexpr eles_per_thread = Int<block_size / threads>{};
|
||||
static_assert(block_size % threads == 0,
|
||||
"block_size must be divisible by the number of threads");
|
||||
|
||||
// Which pre-packed block are we responsible for
|
||||
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
|
||||
auto tB_in = local_tile(
|
||||
B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
|
||||
blk_coord);
|
||||
|
||||
// Find the start offset in the output for this pre-packed block
|
||||
auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
|
||||
|
||||
// Tensor representing a 1:1 mapping to the output space in 1D
|
||||
auto tB_out_linear =
|
||||
make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
|
||||
make_layout(make_shape(block_size)));
|
||||
// Mapping from output space (1D) to input space
|
||||
auto tB_in_linear = make_tensor(
|
||||
tB_in.data(),
|
||||
tB_in.layout()
|
||||
.compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
|
||||
.with_shape(make_shape(block_size)));
|
||||
|
||||
// Tile for this specific thread (could have used a TiledCopy but these work
|
||||
// best with 2d layouts, this is a simple 1d layout so local_tile is enough,
|
||||
// we are also not that concerned with performance for this kernel)
|
||||
auto thr_tB_in_linear =
|
||||
local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
|
||||
auto thr_tB_out_linear =
|
||||
local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
|
||||
|
||||
// Construct a register-backed Tensor with the same shape as each thread's
|
||||
// partition
|
||||
auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
|
||||
|
||||
copy(thr_tB_in_linear, fragment);
|
||||
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
|
||||
}
|
||||
|
||||
template <typename PrepackedLayoutB, typename InLayout>
|
||||
static void prepack_B_template(
|
||||
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
|
||||
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
|
||||
using TileShapeNKL =
|
||||
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
|
||||
auto ilvd_NKbNbKL_to_offset =
|
||||
PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
|
||||
|
||||
PD_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
|
||||
PD_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
|
||||
|
||||
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
|
||||
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
|
||||
auto L_tiles = size<2>(B_layout);
|
||||
|
||||
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
|
||||
|
||||
prepack_B_kernel<128, PrepackedLayoutB>
|
||||
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
|
||||
}
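// Illustrative note (not in the original source): with the PPBlockShape_NK of
// (128, 64) used by PrepackedLayoutBTemplate, a B layout of shape
// (N, K, L) = (4096, 4096, 1) launches a grid of
// (N_tiles, K_tiles, L_tiles) = (32, 64, 1) blocks of 128 threads, each block
// repacking one 128x64 tile of B.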
|
||||
|
||||
}; // namespace machete
|
||||
77 custom_ops/gpu_ops/machete/machete_prepack_launcher.cuh Normal file
@@ -0,0 +1,77 @@
|
||||
#pragma once
|
||||
|
||||
#include "machete_prepack_kernel.cuh"
|
||||
#include "utils/paddle_utils.hpp"
|
||||
#include "utils/scalar_type.h"
|
||||
|
||||
namespace machete {
|
||||
|
||||
struct PrepackBArgs {
|
||||
paddle::Tensor const& B;
|
||||
paddle::DataType a_type;
|
||||
machete::ScalarType b_type;
|
||||
std::optional<paddle::DataType> maybe_group_scales_type;
|
||||
};
|
||||
|
||||
template <typename PrepackedLayoutB>
|
||||
paddle::Tensor prepack_impl(paddle::Tensor const B) {
|
||||
// const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
|
||||
using ElementB = typename PrepackedLayoutB::ElementB;
|
||||
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
|
||||
|
||||
// auto device = B.device();
|
||||
// auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
cudaStream_t stream = B.stream();
|
||||
auto B_ptr = static_cast<ElementB const*>(B.data());
|
||||
// elements per storage item for B
|
||||
auto eles_per_storage =
|
||||
(SizeOf(B.dtype()) * 8) / cute::sizeof_bits_v<ElementB>;
|
||||
|
||||
// paddle B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
|
||||
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
|
||||
// auto Bt_packed = B.transpose();
|
||||
auto Bt_packed = paddle::experimental::transpose(B, {1, 0});
|
||||
|
||||
PD_CHECK(
|
||||
(B.shape()[0] * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
|
||||
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
|
||||
size<1>(PPBlockShape_NK{}));
|
||||
PD_CHECK(B.shape()[1] % size<0>(PPBlockShape_NK{}) == 0,
|
||||
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
|
||||
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
|
||||
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
|
||||
// auto const l_Bt_packed = make_cute_layout<StrideB>(B, "B");
|
||||
|
||||
// convert (N,packed_K,L) layout to (N,K,L) layout
|
||||
// in effect we want to do: blocked_product(layout_Bt_packed,
|
||||
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
|
||||
// Step<_1, _0, _2>{}));
|
||||
// but blocked_product does not support dynamic strides so we implement the
|
||||
// equivalent manually,
|
||||
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
|
||||
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
|
||||
// when s1 == 1
|
||||
PD_CHECK(stride<1>(l_Bt_packed) == 1, "stride<1>(l_Bt_packed) must be 1");
|
||||
// clang-format off
|
||||
auto const layout_Bt = make_layout(
|
||||
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
|
||||
return idx == 1 ? ele * eles_per_storage : ele;
|
||||
}),
|
||||
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
|
||||
return idx != 1 ? ele * eles_per_storage : ele;
|
||||
}));
|
||||
// clang-format on
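// Illustrative example (not in the original source): for 4-bit weights packed
// into 32-bit storage, eles_per_storage = 32 / 4 = 8, so a transposed packed
// layout (N, packed_K, L) : (s0, 1, s2) becomes the unpacked-element view
// (N, 8 * packed_K, L) : (8 * s0, 1, 8 * s2), i.e. only the K mode's shape is
// scaled while the other modes' strides are scaled.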
|
||||
|
||||
// Allocate output
|
||||
paddle::Tensor D = paddle::empty_like(B);
|
||||
|
||||
prepack_B_template<PrepackedLayoutB>(
|
||||
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.data()));
|
||||
|
||||
return D;
|
||||
};
|
||||
|
||||
paddle::Tensor prepack_B_dispatch(PrepackBArgs args);
|
||||
|
||||
}; // namespace machete
|
||||
249 custom_ops/gpu_ops/machete/machete_prepacked_layout.cuh Normal file
@@ -0,0 +1,249 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
// The cutlass include order matters (annoyingly)
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
// clang-format on
|
||||
|
||||
#include "utils/cute_utils.cuh"
|
||||
#include "machete_collective_builder.cuh"
|
||||
#include "machete_interleaving_utils.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct IlvBlkLayoutAuto {};
|
||||
|
||||
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the
// same shape as the prepacked block, all the data for a given thread is
// contiguous in memory. This allows us to use wider shared memory loads when
// loading B from shared memory. The values within a thread are also
// potentially interleaved in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel (this is also why the other element types are
// required along with the kernel schedule).
|
||||
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
|
||||
typename AccumulatorT, class LayoutB, class KernelSchedule,
|
||||
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
|
||||
// clang-format on
|
||||
struct PrepackedLayoutBTemplate {
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementMma = MmaType;
|
||||
|
||||
// Interleave for 4bit types when we are not upconverting to fp8 or int8;
// in those cases we use a LUT with prmt instructions to upconvert, which
// is more efficient if the data is not interleaved. For 8bit+ types, prmt
// instructions make non-interleaved layouts efficient enough that we don't
// need interleaved layouts (and can reuse more of the existing cutlass converts)
|
||||
static constexpr bool should_interleave =
|
||||
sizeof_bits_v<ElementB> <= 4 &&
|
||||
!std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
|
||||
!std::is_same_v<ElementConvert_, int8_t>;
|
||||
|
||||
// Only use interleaved layouts for subbyte weights,
|
||||
using IlvdBlkLayout = std::conditional_t<
|
||||
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
|
||||
std::conditional_t<
|
||||
should_interleave,
|
||||
decltype(get_interleaved_blk_layout<
|
||||
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
|
||||
void>,
|
||||
IlvBlkLayout_>;
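// Illustrative example (not in the original source): for 4-bit weights
// (sizeof_bits_v<ElementB> == 4) being upconverted to half or bfloat16,
// should_interleave is true and IlvdBlkLayout becomes the interleaved block
// layout from get_interleaved_blk_layout. When upconverting to
// cutlass::float_e4m3_t or int8_t, the prmt-LUT path noted above is used
// instead and IlvdBlkLayout collapses to void (no interleaving).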
|
||||
|
||||
// TODO (LucasWilkinson): compare the performance for other sizes
|
||||
// Prepacked block shape, smallest layout atom for loading into registers
// (can contain multiple wgmma instructions worth of data in one block).
// We ideally want this to be configured such that a thread can perform 128bit
// loads, i.e. the amount of data associated with each thread within a
// prepacked block is a multiple of 128bits. When using a cooperative schedule
// we have 256 threads working on a single block at a time, meaning each
// thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data;
// for a 4bit type this would be 128bits
|
||||
using PPBlockShape_NK = Shape<_128, _64>;
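// Illustrative arithmetic (not in the original source): with a cooperative
// schedule (256 threads per block) and a 4-bit ElementB, each thread owns
// 4 * (128 * 64) / 256 = 128 bits of a prepacked block, i.e. exactly one
// 128-bit load; an 8-bit ElementB would give 256 bits, i.e. two 128-bit loads.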
|
||||
|
||||
// Create the shape of the tile anticipated to be used by the GEMM kernel.
// When the kernel executes we will compute `Ct = Bt * At` since the
// quantized weights (B) must be the lhs operand so that they flow through
// registers.
// The _128 here doesn't actually impact the shape of the stored tile directly
// but may impact the op selected by rs_op_selector
|
||||
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
|
||||
size<1>(PPBlockShape_NK{})));
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorB =
|
||||
gmma_rs_tag_to_major_B<LayoutB>();
|
||||
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
|
||||
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
// Prepacked block, (athrid, val) -> (N,K)
|
||||
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
|
||||
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
|
||||
}
|
||||
|
||||
// Prepacked block, (N,K) -> (athrid, val)
|
||||
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
|
||||
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
|
||||
}
|
||||
|
||||
// Prepacked block, (athrid, val) -> (storage_offset)
|
||||
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
|
||||
// Return interleaved layout
|
||||
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
|
||||
}
|
||||
|
||||
// Prepacked block, (athrid, val) -> (storage_offset)
|
||||
// i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
|
||||
auto layout_no_interleave =
|
||||
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
|
||||
|
||||
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
|
||||
return layout_no_interleave;
|
||||
} else {
|
||||
// interleave by transforming FrgV into interleaved blocks where each
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
// if FrgV is {A, B, C, D, E, F, G, H}
// then ((IlvBlk), FrgV) is {A, C, B, D, E, G, F, H}
|
||||
auto frgV = get<1, 0>(layout_no_interleave);
|
||||
auto ilvdBlk = IlvdBlkLayout{};
|
||||
static_assert(size(frgV) % size(ilvdBlk) == 0,
|
||||
"FrgV must be divisible by size(ilvdBlk)");
|
||||
auto ilvd_FrgV = make_layout(
|
||||
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
|
||||
make_stride(stride(ilvdBlk), size(ilvdBlk)));
|
||||
|
||||
// Return interleaved layout
|
||||
return make_layout(
|
||||
get<0>(layout_no_interleave),
|
||||
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
|
||||
}
|
||||
}
|
||||
|
||||
// Prepacked block, (M,K) -> (storage_offset)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
|
||||
// do (M,K) -> (athrid, val) -> (storage_idx)
|
||||
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
|
||||
}
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
|
||||
Shape_NKL shape_mkl) {
|
||||
constexpr auto block_layout = ppblock_TV_to_offset();
|
||||
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
|
||||
auto result = make_layout(
|
||||
block_layout,
|
||||
make_layout(blocks_shape,
|
||||
compact_col_major(blocks_shape, size(block_layout))));
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L))
|
||||
// => ((athrid, val), (BlocksN, BlocksK), L)
|
||||
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
|
||||
}
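// Illustrative example (not in the original source): for a prepacked B of
// shape (N, K, L) = (4096, 4096, 1) with PPBlockShape_NK = (128, 64), the
// blocks shape is (32, 64, 1); each (athrid, val) block occupies
// size(block_layout) = 128 * 64 contiguous storage elements, and consecutive
// blocks are laid out column-major after it.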
|
||||
|
||||
// ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
|
||||
Shape_NKL shape_mkl) {
|
||||
auto layout = TVbNbKL_to_offset(shape_mkl);
|
||||
// for 4-bit elements, having >= 64 values per column
|
||||
// allows TMA to load full 32-byte sectors
|
||||
auto inner_layout =
|
||||
make_layout(make_shape(_256{}, size<0>(layout) / _256{}));
|
||||
|
||||
return make_layout(inner_layout, get<1>(layout), get<2>(layout));
|
||||
}
|
||||
|
||||
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
|
||||
Shape_NKL shape_mkl) {
|
||||
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
|
||||
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
|
||||
auto result = make_layout(
|
||||
block_layout,
|
||||
make_layout(blocks_shape,
|
||||
compact_col_major(blocks_shape, size(block_layout))));
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN,
|
||||
// BlocksK), L)
|
||||
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
|
||||
}
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
auto stride = size(PPBlockShape_NK{});
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
|
||||
}
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
template <class Shape_NKL>
|
||||
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
|
||||
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
|
||||
make_layout(size<1>(PPBlockShape_NK{})));
|
||||
|
||||
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
|
||||
return tiled_A.compose(ppblock_TV_to_NK(), _);
|
||||
}
|
||||
|
||||
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
|
||||
template <class Shape_NKL>
|
||||
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
|
||||
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
|
||||
return blocked_product(ppblock_NK_to_TV(),
|
||||
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
72 custom_ops/gpu_ops/machete/machete_supported_schedules.cu Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "machete_mm_launcher.cuh"
|
||||
#include "machete_prepack_launcher.cuh"
|
||||
|
||||
template <typename T>
|
||||
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
|
||||
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
|
||||
}
|
||||
|
||||
std::vector<std::string> supported_schedules(
|
||||
paddle::DataType a_type, int64_t b_type_id,
|
||||
std::optional<paddle::DataType> maybe_group_scales_type,
|
||||
std::optional<paddle::DataType> maybe_group_zeros_type,
|
||||
std::optional<paddle::DataType> maybe_channel_scales_type,
|
||||
std::optional<paddle::DataType> maybe_token_scales_type,
|
||||
std::optional<paddle::DataType> maybe_out_type) {
|
||||
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
|
||||
auto schedules = machete::supported_schedules_dispatch({
|
||||
.a_type = a_type,
|
||||
.b_type = b_type,
|
||||
.maybe_group_scales_type = maybe_group_scales_type,
|
||||
.maybe_group_zeros_type = maybe_group_zeros_type,
|
||||
.maybe_channel_scales_type = maybe_channel_scales_type,
|
||||
.maybe_token_scales_type = maybe_token_scales_type,
|
||||
.maybe_out_type = maybe_out_type
|
||||
});
|
||||
return schedules;
|
||||
}
|
||||
|
||||
std::vector<std::string> MacheteSupportedSchedules(
|
||||
std::string const& a_type_str, std::string const& b_type_str) {
|
||||
machete::ScalarTypeId b_type_id;
|
||||
paddle::DataType a_type;
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
if (a_type_str == "bfloat16") {
|
||||
a_type = paddle::DataType::BFLOAT16;
|
||||
} else if (a_type_str == "float16") {
|
||||
a_type = paddle::DataType::FLOAT16;
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "a_type_str not supported!");
|
||||
}
|
||||
std::optional<paddle::DataType> maybe_group_scales_type = std::optional<paddle::DataType>(a_type);
|
||||
std::optional<paddle::DataType> maybe_out_type = std::optional<paddle::DataType>(a_type);
|
||||
std::optional<paddle::DataType> maybe_group_zeros_type = std::nullopt;
|
||||
std::optional<paddle::DataType> maybe_channel_scales_type = std::nullopt;
|
||||
std::optional<paddle::DataType> maybe_token_scales_type = std::nullopt;
|
||||
|
||||
auto schedules = supported_schedules(a_type, b_type_id,
|
||||
maybe_group_scales_type,
|
||||
maybe_group_zeros_type,
|
||||
maybe_channel_scales_type,
|
||||
maybe_token_scales_type,
|
||||
maybe_out_type);
|
||||
return schedules;
|
||||
}
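// Illustrative usage sketch (assumption, not part of the original file):
// querying the compiled-in schedule names for bfloat16 activations with
// uint4b8 (4-bit, bias-8) weights.
#if 0
std::vector<std::string> example_schedules =
    MacheteSupportedSchedules("bfloat16", "uint4b8");
#endif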
|
||||
69 custom_ops/gpu_ops/machete/utils/cute_utils.cuh Normal file
@@ -0,0 +1,69 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/cute_utils.cuh
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// layout utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Permute layout based on indices, example:
|
||||
// permute_layout<1, 0>(layout) will swap the two dimensions
|
||||
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
|
||||
template <size_t... I, typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
|
||||
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
|
||||
return cute::make_layout(cute::get<I>(l)...);
|
||||
}
|
||||
|
||||
// is the layout f(x) = x
|
||||
template <typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
|
||||
if constexpr (std::is_same_v<Layout, void>) {
|
||||
return true;
|
||||
} else {
|
||||
constexpr auto coalesced_layout = coalesce(Layout{});
|
||||
if constexpr (rank(coalesced_layout) == 1 &&
|
||||
stride<0>(coalesced_layout) == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// Pointer utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class PointerType>
|
||||
static constexpr auto get_logical_ptr(PointerType* ptr) {
|
||||
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
|
||||
return cute::subbyte_iterator<PointerType>(ptr);
|
||||
} else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// Misc utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename Elements>
|
||||
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
|
||||
constexpr auto bits = sizeof_bits_v<T> * Elements{};
|
||||
if constexpr (bits % 128 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<128>{};
|
||||
} else if constexpr (bits % 64 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<64>{};
|
||||
} else if constexpr (bits % 32 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<32>{};
|
||||
} else if constexpr (bits % 16 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<16>{};
|
||||
} else {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<8>{};
|
||||
}
|
||||
}
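// Illustrative example (not in the original source): for T = cutlass::half_t
// and Elements = cute::_8{}, bits = 16 * 8 = 128, so the copy atom assumes
// 128-bit alignment; for Elements = cute::_2{}, bits = 32 and it falls back
// to 32-bit alignment.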
|
||||
|
||||
}; // namespace cute
|
||||
@@ -0,0 +1,44 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_collective_builder.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
//
|
||||
// MacheteCollectiveBuilder is a wrapper around CollectiveBuilder that allows
// for custom kernel tags, so you can build custom collectives without
// touching the cutlass library headers. Using `CutlassKernelTag` will make it
// fall back to the standard cutlass collective builder.
|
||||
//
|
||||
|
||||
// Use the default Cutlass collective builder, i.e. use an unmodified cutlass
// collective
|
||||
struct CutlassKernelTag {};
|
||||
|
||||
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
|
||||
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
|
||||
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
|
||||
class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType, class Enable = void>
|
||||
struct MacheteCollectiveBuilder {
|
||||
static_assert(sizeof(ElementA) == 0,
|
||||
"Could not build a collective for given parameters.");
|
||||
};
|
||||
|
||||
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
|
||||
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
|
||||
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
|
||||
class StageCountType, class KernelScheduleType>
|
||||
struct MacheteCollectiveBuilder<
|
||||
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
|
||||
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType> {
|
||||
using CollectiveOp = typename CollectiveBuilder<
|
||||
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
|
||||
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
|
||||
};
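// Illustrative sketch (assumption, not part of the original file): a custom
// kernel tag is hooked in with its own partial specialization, along the
// lines of the MacheteKernelTag used by machete_mm_kernel.cuh. The tag name
// below is hypothetical and the mainloop body is elided.
#if 0
struct MyCustomKernelTag {};

template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
          int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK,
          class ClusterShape_MNK, class StageCountType,
          class KernelScheduleType>
struct MacheteCollectiveBuilder<
    MyCustomKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
    ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
    ClusterShape_MNK, StageCountType, KernelScheduleType> {
  // using CollectiveOp = ...;  // build the custom mainloop here
};
#endif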
|
||||
|
||||
}; // namespace cutlass::gemm::collective
|
||||
51 custom_ops/gpu_ops/machete/utils/machete_custom_types.cuh Normal file
@@ -0,0 +1,51 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_custom_types.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, int Bias, bool Signed = false>
|
||||
struct machete_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
|
||||
using Base = integer_subbyte<Bits, Signed>;
|
||||
|
||||
using Storage = typename Base::Storage;
|
||||
using xint_t = typename Base::xint_t;
|
||||
|
||||
using Base::bits_mask_;
|
||||
using Base::sign_mask_;
|
||||
using Base::storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// No operation
|
||||
machete_biased_integer_subbyte() = default;
|
||||
|
||||
/// Conversion from integer type
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(int value)
|
||||
: Base(value) {}
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(unsigned value)
|
||||
: Base(value) {}
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(double value)
|
||||
: Base(value) {}
|
||||
};
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// "GPTQ" types, i.e. symmetric quantization
|
||||
using machete_uint4b8_t = machete_biased_integer_subbyte<4, 8>; // u4b8
|
||||
using machete_uint8b128_t = machete_biased_integer_subbyte<8, 128>; // u8b128
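// Illustrative note (not in the original source): these types store the value
// with a fixed bias, e.g. machete_uint4b8_t holds x + 8 in 4 bits, so the
// stored range 0..15 represents logical values -8..7; machete_uint8b128_t
// similarly stores x + 128 in 8 bits for a logical range of -128..127.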
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, int Bias, bool Signed>
|
||||
struct sizeof_bits<machete_biased_integer_subbyte<Bits, Bias, Signed>> {
|
||||
static constexpr int value = Bits;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
993 custom_ops/gpu_ops/machete/utils/machete_numeric_conversion.cuh Normal file
@@ -0,0 +1,993 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "machete_custom_types.cuh"
|
||||
#include "cute_utils.cuh"
|
||||
#include "machete_type_utils.cuh"
|
||||
|
||||
// this file extends:
|
||||
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
|
||||
// with machete-specific type conversions, namely: machete_uint4b8_t, machete_uint8b128_t
|
||||
// as well as adds interleaved numeric array converters for specific types.
|
||||
// (interleaved numeric array converters can be more efficient for subbyte
|
||||
// types)
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout. Interleaving can
// make subbyte converts more efficient by allowing for efficient extraction
// of subbyte elements from a 32bit register.
|
||||
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
|
||||
class Enable = void>
|
||||
struct InterleavedNumericArrayConverter {
|
||||
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||
|
||||
using result_type = typename Converter::result_type;
|
||||
using source_type = typename Converter::source_type;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
|
||||
printf(
|
||||
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
|
||||
nameof_v<T>, nameof_v<S>, N);
|
||||
} else {
|
||||
printf(
|
||||
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
|
||||
"implemented\n",
|
||||
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
|
||||
}
|
||||
__brkpt();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||
FloatRoundStyle Round>
|
||||
struct InterleavedNumericArrayConverter<
|
||||
IlvBlkLayout, T, S, N, Round,
|
||||
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
|
||||
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||
|
||||
using result_type = typename Converter::result_type;
|
||||
using source_type = typename Converter::source_type;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return Converter::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
template <typename RegConvert32bit, typename T, typename S, int N>
|
||||
struct ArrayConverterPacked32Bit {
|
||||
using result_type = Array<T, N>;
|
||||
using source_type = Array<S, N>;
|
||||
|
||||
using result_packed_8_t = Array<T, 8>;
|
||||
using result_packed_4_t = Array<T, 4>;
|
||||
using result_packed_2_t = Array<T, 2>;
|
||||
using src_packed_8_t = Array<S, 8>;
|
||||
using src_packed_4_t = Array<S, 4>;
|
||||
using src_packed_2_t = Array<S, 2>;
|
||||
|
||||
static_assert(N % 2 == 0, "N must be a multiple of 2");
|
||||
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
|
||||
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
|
||||
static constexpr auto src_elems_per_32bit_reg =
|
||||
32 / cutlass::sizeof_bits_v<S>;
|
||||
|
||||
// Maybe not Valid. ScalarConverter will not actually work unless
|
||||
// NumericConverter<T, S, Round> is implemented. However it won't be used
|
||||
// anyways since we assert N % 2 == 0, just here for compliance with
|
||||
// VectorizedConverter.
|
||||
using ScalarConverter = NumericConverter<T, S>;
|
||||
|
||||
template <typename PackedSrc>
|
||||
CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
|
||||
if constexpr (sizeof(PackedSrc) == 1) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
|
||||
} else if constexpr (sizeof(PackedSrc) == 2) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
|
||||
} else if constexpr (sizeof(PackedSrc) == 4) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
|
||||
} else {
|
||||
static_assert(sizeof(PackedSrc) == 8);
|
||||
return reinterpret_cast<Array<uint32_t, 2> const&>(src);
|
||||
}
|
||||
}
|
||||
|
||||
// The core converter uses bit tricks to construct a known FP16 number, then
|
||||
// does a subtraction in FP16 for the final result.
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||
PackedSrcType const& source) {
|
||||
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
|
||||
static_assert(PackedResultType::kElements == 2 ||
|
||||
PackedResultType::kElements == 4 ||
|
||||
PackedResultType::kElements == 8,
|
||||
"Invalid PackedResultType must be 2, 4 or 8.");
|
||||
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
|
||||
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
|
||||
|
||||
return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE static result_type convert(source_type const& source) {
|
||||
result_type result;
|
||||
using ConverterType =
|
||||
ArrayConverterPacked32Bit<RegConvert32bit,
|
||||
typename result_type::Element,
|
||||
typename source_type::Element, N>;
|
||||
|
||||
if constexpr (src_elems_per_32bit_reg >= 8) {
|
||||
detail::VectorizedConverter::convert<
|
||||
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
|
||||
} else if constexpr (src_elems_per_32bit_reg >= 4) {
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
} else {
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
// into 2 32bit registers.
|
||||
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3, //
|
||||
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7, //
|
||||
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11, //
|
||||
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
|
||||
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
|
||||
uint32_t src) {
|
||||
cutlass::AlignedArray<uint32_t, 2> r;
|
||||
// Determines if the value is in the top half of the LUT (i.e. LUT[8:15])
// if set, or in the bottom half (i.e. LUT[0:7]) if not set. Then move it
// into bit position 0x4 of each nibble so when or'd with final_prmt_base it
// selects the correct candidate. When elements in final_prmt_base
// are >= 0x4, the high candidate is selected (i.e. LUT[8:15]); when elements
// are < 0x4, the low candidate is selected (i.e. LUT[0:7])
|
||||
uint32_t high_bit = (src & 0x88888888) >> 1;
|
||||
|
||||
// `high_bit` is OR'd with final_prmt_base (0x32103210) to find the correct
// value in the LUT (selects correct high or low candidate)
|
||||
const uint32_t final_prmt_base = 0x32103210;
|
||||
|
||||
// Ignore the high bit when indexing into LUT, for each 4bit value
|
||||
// we index into both the high and low candidates then use
|
||||
// high_bit | final_prmt_base to select the correct candidate
|
||||
uint32_t lut_idx = (src & 0x77777777);
|
||||
|
||||
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
|
||||
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
|
||||
(uint32_t(d) << 24);
|
||||
};
|
||||
|
||||
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
|
||||
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
|
||||
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
|
||||
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
|
||||
uint32_t final_prmt_idx = final_prmt_base | high_bit;
|
||||
|
||||
// This uses a look up table to convert packed int4s to packed int8s,
// using the int4 value as the index to prmt. It first selects both the
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
// select the correct candidate.
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .b32 low, high;\n"
|
||||
" prmt.b32 low, %1, %2, %5;\n"
|
||||
" prmt.b32 high, %3, %4, %5;\n"
|
||||
" prmt.b32 %0, low, high, %6;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
|
||||
"r"(final_prmt_idx));
|
||||
}
|
||||
|
||||
return r;
|
||||
};
|
||||
|
||||
// for Array<int8_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<int8_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
|
||||
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
|
||||
0xFC, 0xFD, 0xFE, 0xFF, //
|
||||
0x00, 0x01, 0x02, 0x03, //
|
||||
0x04, 0x05, 0x06, 0x07>(src_[0]);
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::float_e4m3_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::float_e4m3_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::float_e4m3_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
|
||||
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
|
||||
0xC8, 0xC4, 0xC0, 0xB8, //
|
||||
0x00, 0x38, 0x40, 0x44, //
|
||||
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
|
||||
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
|
||||
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
|
||||
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
|
||||
// We use inline asm instead of __byte_perm intrinsic since we don't want
|
||||
// the documented (& 0x7) on the index. NVCC might be able to optimize it
|
||||
// out since the index is a constexpr, but we choose to be safe about it
|
||||
// here.
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
static_assert(RegArray::kElements <= 4,
|
||||
"Too many inputs for F16 -> I4 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
      // Since the stored 4-bit values are biased by 8, stored_val = (x + 8)
      // and we want to recover x as an fp16 value.
      // The XOR below does the following:
      //  1) Sets the exponent bits of the FP16 to the fp16 magic number, so
      //     that each register holds {1024 + 16*(x1+8), 1024 + (x0+8)}, where
      //     x1 is the high nibble and x0 is the low nibble; the hfma below then
      //     removes these offsets (1032 for the low half, 72 for the high half
      //     after the 1/16 scaling).
      // The AND does the following:
      //  1) Clears the bits of the int4 that this half ignores.
      // We use lop3 so that the AND and XOR take a single instruction.
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
static constexpr uint32_t and_mask = 0xFFF0FF0F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
      // We will issue 2 hfmas that do the following:
      // {x1, x0} = {1024 + 16*(x1+8), 1024 + (x0+8)} * {1/16, 1} - {72, 1032}
      //          = {16*x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
      static constexpr uint32_t hfma_bias_rep = 0xD480E408;   // {-72, -1032}
      static constexpr uint32_t hfma_scale_rep = 0x2C003C00;  // {1 / 16, 1}
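      // Worked check: the lop3 above left the low half as the fp16 bit pattern
      // 0x6400 | (x0+8), i.e. the value 1024 + (x0+8); scaling by 1 and adding
      // -1032 yields x0. The high half holds 1024 + 16*(x1+8); scaling by 1/16
      // gives 64 + (x1+8) = x1 + 72, and adding -72 yields x1.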
|
||||
|
||||
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
|
||||
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::half_t, machete_uint4b8_t, N,
|
||||
Round, void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
auto src_ = src >> (4 * (ii));
|
||||
r[ii + 0] = src_;
|
||||
r[ii + 1] = src_;
|
||||
|
||||
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
static constexpr uint32_t high_nib_mask = 0x00F000F0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 1])
|
||||
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
|
||||
// For high nibble:
|
||||
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
|
||||
// - {72, 72}
|
||||
static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032}
|
||||
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
|
||||
static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72}
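          // Note that 72 = 1024/16 + 8, i.e. the scaled magic offset plus the
          // +8 storage bias, so a single fma removes both for the high nibble;
          // the low nibble only needs the plain 1032 = 1024 + 8 subtraction.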
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
|
||||
}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(high_nib_scale),
|
||||
reinterpret_cast<const half2&>(high_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::half_t, uint4_t, N, Round,
|
||||
void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<uint4_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
auto src_ = src >> (4 * (ii));
|
||||
r[ii + 0] = src_;
|
||||
r[ii + 1] = src_;
|
||||
|
||||
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
static constexpr uint32_t high_nib_mask = 0x00F000F0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 1])
|
||||
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
|
||||
// For high nibble:
|
||||
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
|
||||
static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024}
|
||||
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
|
||||
static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
|
||||
}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(high_nib_scale),
|
||||
reinterpret_cast<const half2&>(high_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
uint32_t const prmt_indices[2] = {0x5150, 0x5352};
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(start_byte_for_fp16),
|
||||
"r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
|
||||
static constexpr uint32_t bias_rep = 0x64806480;
|
||||
const half2& bias = reinterpret_cast<const half2&>(bias_rep);
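      // The prmt above packed each uint8 into the mantissa of an fp16 whose
      // exponent byte is 0x64, i.e. the value 1024 + u8. 0x6480 is the fp16 for
      // 1152 = 1024 + 128, so the subtraction yields u8 - 128, undoing the b128
      // bias.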
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hsub2(fp16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::float, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<float, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<float, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
PackedResultType r;
|
||||
|
||||
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
|
||||
// u8x4 source and stores the result in r (without introducing extra
|
||||
// cvt.u32.u8 instruction)
|
||||
uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
|
||||
uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
|
||||
for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
|
||||
result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
|
||||
        // Subtract the magic number 0x4B000000 (= 8388608.f), plus the 128
        // bias, from r[ii] in floating-point arithmetic to obtain the result
|
||||
r[ii] -= (8388608.f + 128.f); // fold in -128 bias
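        // 0x4B000000 is the float 8388608 = 2^23, so placing the uint8 in the
        // low mantissa byte yields exactly 8388608 + u8; subtracting
        // 8388608 + 128 gives u8 - 128, the unbiased value.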
|
||||
}
|
||||
|
||||
return r;
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src_reg = src_[0];
|
||||
// Hold output BF16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
uint32_t src_reg_shifted = src_reg >> 4;
|
||||
|
||||
      // Below constructs, for each output register ii, the temporary
      //   {0x00, byte with i4_{2*ii+1} in its low nibble,
      //    0x00, byte with i4_{2*ii}   in its low nibble}
      // src_reg supplies the even nibbles and src_reg_shifted the odd ones; the
      // 0xF selectors yield 0x00 bytes (sign-replication of bit 31 of
      // src_reg_shifted, which is always 0 after the >> 4).
|
||||
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
|
||||
static_assert(RegArray::kElements <= 4,
|
||||
"Too many inputs for uint4b8_t -> BF16 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
|
||||
// we are trying to construct x and a BF16 value
|
||||
// The below XOR does the following:
|
||||
// 1) Sets the exponent bits of the BF16 to the correct value for the
|
||||
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
|
||||
// and subtracting 136 to get {x1, x0}
|
||||
static constexpr uint32_t xor_mask = 0x43004300;
|
||||
static constexpr uint32_t and_mask = 0x000F000F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
      // Finally, subtract the bias from both BF16 halves:
      //   {x1, x0} = {128 + (x1+8), 128 + (x0+8)} - {136, 136}
|
||||
|
||||
// This is the BF16 {136, 136} represented as an integer.
|
||||
static constexpr uint32_t bias_rep = 0x43084308;
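      // In bf16, 0x4300 encodes 128.0 and 0x4308 encodes 136.0 (the mantissa
      // step is 1 at this exponent), so this subtraction removes both the 128
      // magic offset and the +8 storage bias in one step.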
|
||||
const __nv_bfloat162& bias =
|
||||
reinterpret_cast<const __nv_bfloat162&>(bias_rep);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hsub2(bf16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::bfloat16_t, machete_uint4b8_t, N,
|
||||
Round, void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t or_mask = 0x43004300;
|
||||
|
||||
// Unlike float16 where the mantissa is large enough to contain 2
|
||||
// nibbles, bfloat16 can only fit one, so we can only convert one
|
||||
// nibble at a time
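      // (A nibble placed at bits 4..7 would contribute up to 16*15 = 240, which
      // does not fit in bf16's 7 mantissa bits, hence one nibble per register.)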
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
r[ii] = src >> (4 * ii);
|
||||
|
||||
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
|
||||
static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136}
|
||||
|
||||
{
|
||||
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val,
|
||||
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::bfloat16_t, uint4_t, N, Round,
|
||||
void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<uint4_t, N>;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t or_mask = 0x43004300;
|
||||
|
||||
// Unlike float16 where the mantissa is large enough to contain 2
|
||||
// nibbles, bfloat16 can only fit one, so we can only convert one
|
||||
// nibble at a time
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
r[ii] = src >> (4 * ii);
|
||||
|
||||
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
|
||||
static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128}
|
||||
|
||||
{
|
||||
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val,
|
||||
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
|
||||
using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
|
||||
using src_packed_4_t = Array<machete_uint8b128_t, 4>;
|
||||
using src_packed_2_t = Array<machete_uint8b128_t, 2>;
|
||||
|
||||
// Not Valid, not supported, only here to satisfy the interface and to avoid
|
||||
// a compile error. ScalarConverter will not actually work until
|
||||
// NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round> is
|
||||
// implemented
|
||||
using ScalarConverter =
|
||||
NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round>;
|
||||
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||
PackedSrcType const& source) {
|
||||
static_assert(
|
||||
(platform::is_same<PackedSrcType, src_packed_2_t>::value &&
|
||||
platform::is_same<PackedResultType, result_packed_2_t>::value) ||
|
||||
(platform::is_same<PackedSrcType, src_packed_4_t>::value &&
|
||||
platform::is_same<PackedResultType, result_packed_4_t>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
|
||||
"convert dispatch.");
|
||||
|
||||
NumericArrayConverter<float, machete_uint8b128_t, PackedResultType::kElements,
|
||||
Round>
|
||||
convert_uint8_to_f32;
|
||||
Array<float, PackedResultType::kElements> tmp =
|
||||
convert_uint8_to_f32(source);
|
||||
NumericArrayConverter<cutlass::bfloat16_t, float,
|
||||
PackedResultType::kElements, Round>
|
||||
convert_f32_to_bf16_;
|
||||
return convert_f32_to_bf16_(tmp);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
result_type result;
|
||||
using ConverterType =
|
||||
NumericArrayConverter<typename result_type::Element,
|
||||
typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
|
||||
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<cutlass::half_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
|
||||
template <typename PackedResultType, int src_regs>
|
||||
CUTLASS_DEVICE static PackedResultType convert(
|
||||
Array<uint32_t, src_regs> src) {
|
||||
// Hold output int8s in reg. We need 1 reg for every 4 elements
|
||||
using RegArray = cutlass::AlignedArray<
|
||||
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
|
||||
RegArray r;
|
||||
|
||||
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
|
||||
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
|
||||
|
||||
*reinterpret_cast<half2*>(&src[0]) =
|
||||
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
|
||||
|
||||
if constexpr (src_regs > 1) {
|
||||
*reinterpret_cast<half2*>(&src[1]) =
|
||||
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
|
||||
}
|
||||
|
||||
static_assert(PackedResultType::kElements <= 4);
|
||||
uint32_t uint8s;
|
||||
static constexpr uint32_t MASK_0246 = 0x6420;
|
||||
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
||||
: "=r"(uint8s)
|
||||
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
|
||||
"n"(MASK_0246));
|
||||
|
||||
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
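      // 0x6480 is the fp16 for 1152 = 1024 + 128, so after the hadd2 each half
      // lane holds 1024 + (x + 128) rounded to the nearest integer, and its low
      // mantissa byte is the excess-128 encoding of x (for x in [-128, 127]).
      // The prmt with mask 0x6420 gathers bytes 0 and 2 of the two source
      // registers, and the XOR with 0x80 per byte converts excess-128 to
      // two's-complement int8.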
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(int8s);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
custom_ops/gpu_ops/machete/utils/machete_type_utils.cuh (new file, 43 lines)
@@ -0,0 +1,43 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"

#include "machete_custom_types.cuh"

namespace cutlass {

template <typename T>
struct nameof {
  static constexpr char const* value = "unknown";
};

template <typename T>
inline constexpr auto nameof_v = nameof<T>::value;

#define NAMEOF_TYPE(T)                        \
  template <>                                 \
  struct nameof<T> {                          \
    static constexpr char const* value = #T;  \
  };

NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)

NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)

NAMEOF_TYPE(machete_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(machete_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)

};  // namespace cutlass
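// For example, cutlass::nameof_v<cutlass::half_t> evaluates to the string
// "half_t", while any type without a NAMEOF_TYPE specialization falls back to
// "unknown".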
custom_ops/gpu_ops/machete/utils/paddle_utils.hpp (new file, 161 lines)
@@ -0,0 +1,161 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/half.h"
|
||||
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
|
||||
namespace cute {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F, class G, int... I>
|
||||
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
|
||||
seq<I...>) {
|
||||
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
|
||||
}
|
||||
|
||||
template <class F, int... I>
|
||||
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
|
||||
return make_shape(f(I)...);
|
||||
}
|
||||
|
||||
}; // namespace detail
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
|
||||
if constexpr (cute::is_tuple<T>::value) {
|
||||
return detail::tapply_with_idx(
|
||||
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
|
||||
tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// calls: make_shape(f(0), f(1), ..., f(N-1))
|
||||
template <int N, class F>
|
||||
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
|
||||
return detail::make_shape_from_idx(f, make_seq<N>{});
|
||||
}
|
||||
|
||||
}; // namespace cute
|
||||
|
||||
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
|
||||
// shape of the passed in tensor and the strides are of type `Stride` and
|
||||
// contain the strides of the passed in tensor, checking that any static strides
|
||||
// in `Stride{}` match the strides of the passed in tensor.
|
||||
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and the extra
|
||||
// strides are set to be 0 or 1.
|
||||
template <typename Stride>
|
||||
static inline auto make_cute_layout(paddle::Tensor const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
PD_CHECK(tensor.shape().size() <= rank(Stride{}));
|
||||
auto stride = cute::transform_with_idx(
|
||||
Stride{}, [&](auto const& stride_ele, auto const& idx) {
|
||||
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||
|
||||
if (idx < tensor.shape().size()) {
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
PD_CHECK(StrideEle::value == tensor.strides()[idx], "Expected ",
|
||||
name, ".strides()[", idx, "] to be ", StrideEle::value, ", but got ", tensor.strides()[idx], ". ");
|
||||
return StrideEle{};
|
||||
} else {
|
||||
if (tensor.shape()[idx] == 1) {
|
||||
// use 0 stride for dims with size 1, this is easier for
|
||||
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
||||
return StrideEle{0};
|
||||
} else {
|
||||
return tensor.strides()[idx];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Extra strides are assumed to be 0 or 1
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||
}
|
||||
return StrideEle{};
|
||||
}
|
||||
});
|
||||
|
||||
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
|
||||
if (idx < tensor.shape().size())
|
||||
return tensor.shape()[idx];
|
||||
else
|
||||
return int64_t(1);
|
||||
});
|
||||
|
||||
return make_layout(shape, stride);
|
||||
}
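// Illustrative usage (hypothetical tensor `B`): for a contiguous row-major
// [K, N] weight one could write
//   auto layout_B =
//       make_cute_layout<cute::Stride<int64_t, cute::Int<1>>>(B, "B");
// which checks that the innermost stride is statically 1 and yields a
// (K, N):(N, 1) cute layout.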
|
||||
|
||||
template <typename Stride>
|
||||
static inline auto maybe_make_cute_layout(
|
||||
std::optional<paddle::Tensor> const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||
|
||||
if (tensor) {
|
||||
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
|
||||
} else {
|
||||
return std::optional<Layout>{};
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Paddle dtype to Cutlass Type (equivalent_cutlass_type)
|
||||
//
|
||||
|
||||
template <typename T>
|
||||
struct equivalent_cutlass_type {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<phi::dtype::float16> {
|
||||
using type = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<phi::dtype::bfloat16> {
|
||||
using type = cutlass::bfloat16_t;
|
||||
};
|
||||
|
||||
//
|
||||
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
|
||||
//
|
||||
|
||||
// Return a `phi::CppTypeToDataType`-compatible type, i.e. get the Paddle (phi)
// C++ type that is equivalent to T, e.g.: `cutlass::half_t -> phi::dtype::float16`
|
||||
template <typename T>
|
||||
struct equivalent_scalar_type {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::half_t> {
|
||||
using type = phi::dtype::float16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::bfloat16_t> {
|
||||
using type = phi::dtype::bfloat16;
|
||||
};
|
||||
|
||||
// get the equivalent paddle::DataType tag from a compile-time type
|
||||
template <typename T>
|
||||
static inline constexpr paddle::DataType equivalent_scalar_type_v =
|
||||
phi::CppTypeToDataType<equivalent_scalar_type_t<T>>::Type();
|
||||
custom_ops/gpu_ops/machete/utils/scalar_type.h (new file, 372 lines)
@@ -0,0 +1,372 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/phi/common/data_type.h"
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
namespace machete {
|
||||
|
||||
//
|
||||
// ScalarType can represent a wide range of floating point and integer types,
|
||||
// in particular it can be used to represent sub-byte data types (something
|
||||
// that torch.dtype currently does not support).
|
||||
//
|
||||
// The type definitions on the Python side can be found in: vllm/scalar_type.py
|
||||
// these type definitions should be kept up to date with any Python API changes
|
||||
// here.
|
||||
//
|
||||
class ScalarType {
|
||||
public:
|
||||
enum NanRepr : uint8_t {
|
||||
NAN_NONE = 0, // nans are not supported
|
||||
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
||||
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
||||
|
||||
NAN_REPR_ID_MAX
|
||||
};
|
||||
|
||||
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
|
||||
int32_t bias, bool finite_values_only = false,
|
||||
NanRepr nan_repr = NAN_IEEE_754)
|
||||
: exponent(exponent),
|
||||
mantissa(mantissa),
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr) {};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
}
|
||||
|
||||
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits, false, bias);
|
||||
}
|
||||
|
||||
// IEEE 754 compliant floating point type
|
||||
static constexpr ScalarType float_IEEE754(uint8_t exponent,
|
||||
uint8_t mantissa) {
|
||||
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
|
||||
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
|
||||
}
|
||||
|
||||
// IEEE 754 non-compliant floating point type
|
||||
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
|
||||
bool finite_values_only,
|
||||
NanRepr nan_repr) {
|
||||
// PADDLE_ENFORCE(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
|
||||
// PADDLE_ENFORCE(nan_repr != NAN_IEEE_754,
|
||||
// "use `float_IEEE754` constructor for floating point types that "
|
||||
// "follow IEEE 754 conventions");
|
||||
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
|
||||
nan_repr);
|
||||
}
|
||||
|
||||
uint8_t const exponent; // size of the exponent field (0 for integer types)
|
||||
uint8_t const mantissa; // size of the mantissa field (size of the integer
|
||||
// excluding the sign bit for integer types)
|
||||
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
||||
// sign bit)
|
||||
int32_t const bias; // stored values equal value + bias,
|
||||
// used for quantized type
|
||||
|
||||
// Extra Floating point info
|
||||
bool const finite_values_only; // i.e. no +/-inf if true
|
||||
NanRepr const nan_repr; // how NaNs are represented
|
||||
// (not applicable for integer types)
|
||||
|
||||
using Id = int64_t;
|
||||
|
||||
private:
|
||||
// Field size in id
|
||||
template <typename T_>
|
||||
static constexpr size_t member_id_field_width() {
|
||||
using T = std::decay_t<T_>;
|
||||
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init, typename Member, typename... Rest>
|
||||
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
|
||||
Rest... rest) {
|
||||
auto new_val = f(val, member);
|
||||
if constexpr (sizeof...(rest) > 0) {
|
||||
return reduce_members_helper(f, new_val, rest...);
|
||||
} else {
|
||||
return new_val;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
constexpr auto reduce_members(Fn f, Init init) const {
|
||||
// Should be in constructor order for `from_id`
|
||||
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
|
||||
finite_values_only, nan_repr);
|
||||
};
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
static constexpr auto reduce_member_types(Fn f, Init init) {
|
||||
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
|
||||
return dummy_type.reduce_members(f, init);
|
||||
};
|
||||
|
||||
static constexpr auto id_size_bits() {
|
||||
return reduce_member_types(
|
||||
[](int acc, auto member) -> int {
|
||||
return acc + member_id_field_width<decltype(member)>();
|
||||
},
|
||||
0);
|
||||
}
|
||||
|
||||
public:
|
||||
// unique id for this scalar type that can be computed at compile time for
|
||||
// c++17 template specialization this is not needed once we migrate to
|
||||
// c++20 and can pass literal classes as template parameters
|
||||
constexpr Id id() const {
|
||||
static_assert(id_size_bits() <= sizeof(Id) * 8,
|
||||
"ScalarType id is too large to be stored");
|
||||
|
||||
auto or_and_advance = [](std::pair<Id, uint32_t> result,
|
||||
auto member) -> std::pair<Id, uint32_t> {
|
||||
auto [id, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<decltype(member)>();
|
||||
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
|
||||
<< bit_offset,
|
||||
bit_offset + bits};
|
||||
};
|
||||
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
|
||||
}
|
||||
|
||||
// create a ScalarType from an id, for c++17 template specialization,
|
||||
// this is not needed once we migrate to c++20 and can pass literal
|
||||
// classes as template parameters
|
||||
static constexpr ScalarType from_id(Id id) {
|
||||
auto extract_and_advance = [id](auto result, auto member) {
|
||||
using T = decltype(member);
|
||||
auto [tuple, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<T>();
|
||||
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
|
||||
((uint64_t(1) << bits) - 1));
|
||||
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
|
||||
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
|
||||
};
|
||||
|
||||
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
|
||||
std::pair<std::tuple<>, int>{});
|
||||
return std::apply([](auto... args) { return ScalarType(args...); },
|
||||
tuple_args);
|
||||
}
|
||||
|
||||
constexpr int64_t size_bits() const {
|
||||
return mantissa + exponent + is_signed();
|
||||
}
|
||||
constexpr bool is_signed() const { return signed_; }
|
||||
constexpr bool is_integer() const { return exponent == 0; }
|
||||
constexpr bool is_floating_point() const { return exponent > 0; }
|
||||
constexpr bool is_ieee_754() const {
|
||||
return is_floating_point() && finite_values_only == false &&
|
||||
nan_repr == NAN_IEEE_754;
|
||||
}
|
||||
constexpr bool has_nans() const {
|
||||
return is_floating_point() && nan_repr != NAN_NONE;
|
||||
}
|
||||
constexpr bool has_infs() const {
|
||||
return is_floating_point() && finite_values_only == false;
|
||||
}
|
||||
constexpr bool has_bias() const { return bias != 0; }
|
||||
|
||||
private:
|
||||
double _floating_point_max() const {
|
||||
PADDLE_ENFORCE(mantissa <= 52 && exponent <= 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
|
||||
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
|
||||
max_mantissa -= 1;
|
||||
}
|
||||
|
||||
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
|
||||
PADDLE_ENFORCE(exponent < 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
max_exponent += 1;
|
||||
}
|
||||
|
||||
// adjust the exponent to match that of a double
|
||||
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
|
||||
// is the exponent bits), there is some precedent for non-standard biases,
|
||||
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
|
||||
// but to avoid premature over complication we are just assuming the
|
||||
// standard exponent bias until there is a need to support non-standard
|
||||
// biases
|
||||
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
|
||||
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
|
||||
|
||||
uint64_t max_exponent_double =
|
||||
max_exponent - exponent_bias + exponent_bias_double;
|
||||
|
||||
// shift the mantissa into the position for a double and
|
||||
// the exponent
|
||||
uint64_t double_raw =
|
||||
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
|
||||
|
||||
return *reinterpret_cast<double*>(&double_raw);
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_max() const {
|
||||
if (is_floating_point()) {
|
||||
return {_floating_point_max()};
|
||||
} else {
|
||||
// PADDLE_ENFORCE(size_bits() < 64 || size_bits() == 64 && is_signed(),
|
||||
// "Cannot represent max as a int64_t");
|
||||
return {(int64_t(1) << mantissa) - 1};
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_min() const {
|
||||
if (is_floating_point()) {
|
||||
// PADDLE_ENFORCE(is_signed(),
|
||||
// "We currently assume all floating point types are signed");
|
||||
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
|
||||
|
||||
double max = _floating_point_max();
|
||||
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
|
||||
uint64_t min_raw = max_raw | sign_bit_double;
|
||||
return {*reinterpret_cast<double*>(&min_raw)};
|
||||
} else {
|
||||
// PADDLE_ENFORCE(!is_signed() || size_bits() <= 64,
|
||||
// "Cannot represent min as a int64_t");
|
||||
if (is_signed()) {
|
||||
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
|
||||
// then perform an arithmetic shift right to set all the bits above
|
||||
// (size_bits() - 1) to 1
|
||||
return {INT64_MIN >> (64 - size_bits())};
|
||||
} else {
|
||||
return {int64_t(0)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Max representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> max() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_max());
|
||||
}
|
||||
|
||||
// Min representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> min() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_min());
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
* for floating point types (leading f) the scheme is:
|
||||
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
* flags:
|
||||
* - no-flags: means it follows IEEE 754 conventions
|
||||
* - f: means finite values only (no infinities)
|
||||
* - n: means nans are supported (non-standard encoding)
|
||||
* for integer types the scheme is:
|
||||
* `[u]int<size_bits>[b<bias>]`
|
||||
* - if bias is not present it means its zero
|
||||
*/
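    // For example, kU4B8.str() returns "uint4b8" and kFE4M3fn.str() returns
    // "float8_e4m3fn".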
|
||||
if (is_floating_point()) {
|
||||
auto ret = "float" + std::to_string(size_bits()) + "_e" +
|
||||
std::to_string(exponent) + "m" + std::to_string(mantissa);
|
||||
if (!is_ieee_754()) {
|
||||
if (finite_values_only) {
|
||||
ret += "f";
|
||||
}
|
||||
if (nan_repr != NAN_NONE) {
|
||||
ret += "n";
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
} else {
|
||||
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
|
||||
if (has_bias()) {
|
||||
ret += "b" + std::to_string(bias);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr bool operator==(ScalarType const& other) const {
|
||||
return mantissa == other.mantissa && exponent == other.exponent &&
|
||||
bias == other.bias && signed_ == other.signed_ &&
|
||||
finite_values_only == other.finite_values_only &&
|
||||
nan_repr == other.nan_repr;
|
||||
}
|
||||
};
|
||||
|
||||
using ScalarTypeId = machete::ScalarType::Id;
|
||||
|
||||
// "rust style" names generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
static inline constexpr auto kS4 = machete::ScalarType::int_(4);
|
||||
static inline constexpr auto kU4 = machete::ScalarType::uint(4);
|
||||
static inline constexpr auto kU4B8 = machete::ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kS8 = machete::ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = machete::ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = machete::ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE2M1f =
|
||||
machete::ScalarType::float_(2, 1, true, machete::ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE3M2f =
|
||||
machete::ScalarType::float_(3, 2, true, machete::ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
machete::ScalarType::float_(4, 3, true, machete::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = machete::ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = machete::ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = machete::ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
// // Fixed width style names, generally following:
|
||||
// // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
constexpr auto kInt4 = kS4;
|
||||
constexpr auto kUint4 = kU4;
|
||||
constexpr auto kUint4b8 = kU4B8;
|
||||
constexpr auto kInt8 = kS8;
|
||||
constexpr auto kUint8 = kU8;
|
||||
constexpr auto kUint8b128 = kU8B128;
|
||||
constexpr auto kFloat4_e2m1f = kFE2M1f;
|
||||
constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
constexpr auto kFloat16_e8m7 = kFE8M7;
|
||||
constexpr auto kFloat16_e5m10 = kFE5M10;
|
||||
|
||||
// colloquial names
|
||||
constexpr auto kHalf = kFE5M10;
|
||||
constexpr auto kFloat16 = kHalf;
|
||||
constexpr auto kFloat16Id = kFloat16.id();
|
||||
|
||||
constexpr auto kInt32 = phi::DataType::INT32;
|
||||
constexpr auto kInt64 = phi::DataType::INT64;
|
||||
constexpr auto kBool = phi::DataType::BOOL;
|
||||
constexpr auto kFloat8_e4m3fn = phi::DataType::FLOAT8_E4M3FN;
|
||||
constexpr auto kBFloat16 = phi::DataType::BFLOAT16;
|
||||
constexpr auto kFloat32 = phi::DataType::FLOAT32;
|
||||
constexpr auto kByte = phi::DataType::INT8;
|
||||
|
||||
}; // namespace machete
|
||||
custom_ops/gpu_ops/merge_prefill_decode_output.cu (new file, 117 lines)
@@ -0,0 +1,117 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
template <int warps, typename T>
|
||||
__global__ void FillEncoderDecoderResKernel(
|
||||
T * encoder_res_data,
|
||||
T * decoder_res_data,
|
||||
const int * seq_lens_encoder,
|
||||
const int * seq_lens_decoder,
|
||||
const int * seq_lens_this_time,
|
||||
const int * cu_seq_q,
|
||||
const int head_num,
|
||||
const int head_dim) {
|
||||
|
||||
const int bidb = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int bidt = blockIdx.z * warps;
|
||||
const int tid = threadIdx.x;
|
||||
const int warp_id = tid / 32;
|
||||
    const int lane_id = tid % 32;
|
||||
const int token_id = bidt + warp_id;
|
||||
|
||||
const int seq_len_encoder = seq_lens_encoder[bidb];
|
||||
const int seq_len_decoder = seq_lens_decoder[bidb];
|
||||
const int seq_len_this_time = seq_lens_this_time[bidb];
|
||||
|
||||
if (seq_len_encoder > 0 || seq_len_decoder == 0 || token_id >= seq_len_this_time) {
|
||||
return;
|
||||
}
|
||||
|
||||
    const int load_idx = ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + lane_id * 4;
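    // Each of the 32 lanes copies one float2 (4 consecutive 16-bit elements),
    // so a warp covers head_dim = 128 for a single (token, head) pair.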
|
||||
|
||||
*reinterpret_cast<float2*>(encoder_res_data + load_idx) = *reinterpret_cast<float2*>(decoder_res_data + load_idx);
|
||||
}
|
||||
|
||||
void MergePrefillDecodeOutput(
|
||||
const paddle::Tensor &encoder_res,
|
||||
const paddle::Tensor &decoder_res,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &cu_seq_q,
|
||||
const int head_num,
|
||||
const int head_dim,
|
||||
const int max_token) {
|
||||
|
||||
if (head_dim != 128) {
|
||||
PD_THROW("Only supported head_dim = 128");
|
||||
}
|
||||
const int batch_size = seq_lens_encoder.shape()[0];
|
||||
constexpr int warps = 4;
|
||||
const int tokens_block = (max_token + warps - 1) / warps;
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = batch_size;
|
||||
grid_dims.y = head_num;
|
||||
grid_dims.z = tokens_block;
|
||||
|
||||
if (encoder_res.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
FillEncoderDecoderResKernel<warps>
|
||||
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
|
||||
const_cast<T*>(encoder_res.data<T>()),
|
||||
const_cast<T*>(decoder_res.data<T>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
head_num,
|
||||
head_dim
|
||||
);
|
||||
} else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
FillEncoderDecoderResKernel<warps>
|
||||
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
|
||||
const_cast<T*>(encoder_res.data<T>()),
|
||||
const_cast<T*>(decoder_res.data<T>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
head_num,
|
||||
head_dim
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(merge_prefill_decode_output)
|
||||
.Inputs({"encoder_res",
|
||||
"decoder_res",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"cu_seq_q"})
|
||||
.Outputs({"res"})
|
||||
.Attrs({"head_num: int",
|
||||
"head_dim: int",
|
||||
"max_token: int"})
|
||||
.SetInplaceMap({{"encoder_res", "res"}})
|
||||
.SetKernelFn(PD_KERNEL(MergePrefillDecodeOutput));
|
||||
custom_ops/gpu_ops/moba_attn/moba_attn.cu (new file, 330 lines)
@@ -0,0 +1,330 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn.h"
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MobaAttention(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& cu_seq_q_pack,
|
||||
const paddle::Tensor& q_pack_tokens,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& rope_sin_cos,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::optional<paddle::Tensor>& attn_gate_weight,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_seq_len,
|
||||
const int max_enc_len_this_time,
|
||||
const int max_dec_len_this_time,
|
||||
const int moba_encoder_top_k_left,
|
||||
const int moba_encoder_top_k_right,
|
||||
const int moba_use_encoder_seq_limit,
|
||||
const int moba_decoder_top_k_left,
|
||||
const int moba_decoder_top_k_right,
|
||||
const int moba_use_decoder_seq_limit,
|
||||
const bool moba_use_mlp,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
paddle::Tensor out = paddle::empty({qkv.dims()[0], head_num * head_dim}, qkv.dtype(), qkv.place());
|
||||
if (max_dec_len_this_time > 0) {
|
||||
MobaDecoderAttnWriteCacheKv(
|
||||
qkv,
|
||||
q_input,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
rope_sin_cos,
|
||||
k_block_means,
|
||||
qkv_bias,
|
||||
cache_k_quant_scale,
|
||||
cache_v_quant_scale,
|
||||
cache_k_dequant_scale,
|
||||
cache_v_dequant_scale,
|
||||
cache_k_zero_points,
|
||||
cache_v_zero_points,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_len,
|
||||
cache_quant_type_str);
|
||||
|
||||
auto qk_gate_weight = MobaQKGemm(
|
||||
q_input,
|
||||
k_block_means,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_dec_len_this_time,
|
||||
max_dec_len_this_time,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
true,
|
||||
moba_use_decoder_seq_limit
|
||||
)[0];
|
||||
|
||||
auto qk_gate_topk_idx = QkSortDecoder(
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
moba_decoder_top_k_left,
|
||||
moba_decoder_top_k_right,
|
||||
moba_use_decoder_seq_limit
|
||||
)[0];
|
||||
|
||||
MobaDecoderAttn(
|
||||
q_input,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
k_block_means,
|
||||
out,
|
||||
qk_gate_topk_idx,
|
||||
cache_k_quant_scale,
|
||||
cache_v_quant_scale,
|
||||
cache_k_dequant_scale,
|
||||
cache_v_dequant_scale,
|
||||
cache_k_zero_points,
|
||||
cache_v_zero_points,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_len,
|
||||
moba_use_decoder_seq_limit,
|
||||
max_dec_len_this_time,
|
||||
max_dec_len_this_time,
|
||||
cache_quant_type_str
|
||||
);
|
||||
}
|
||||
|
||||
if (max_enc_len_this_time > 0) {
|
||||
FusedBlockMeanAndRope(
|
||||
qkv,
|
||||
k_block_means,
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
rope_sin_cos,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
qkv_bias,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_len,
|
||||
max_enc_len_this_time,
|
||||
max_enc_len_this_time,
|
||||
cache_quant_type_str
|
||||
);
|
||||
|
||||
MobaEncoderAttnWriteCacheKv(
|
||||
k_input,
|
||||
v_input,
|
||||
cu_seq_k,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
cache_k_quant_scale,
|
||||
cache_v_quant_scale,
|
||||
cache_k_dequant_scale,
|
||||
cache_v_dequant_scale,
|
||||
cache_k_zero_points,
|
||||
cache_v_zero_points,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_enc_len_this_time,
|
||||
cache_quant_type_str
|
||||
);
|
||||
|
||||
GetKVFromCache(
|
||||
k_input,
|
||||
v_input,
|
||||
cu_seq_k,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
cache_k_dequant_scale,
|
||||
cache_v_dequant_scale,
|
||||
cache_k_zero_points,
|
||||
cache_v_zero_points,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_len,
|
||||
max_enc_len_this_time + max_dec_len_this_time,
|
||||
cache_quant_type_str
|
||||
);
|
||||
|
||||
paddle::Tensor k_gate_mlp;  // declared at this scope so the MLP gate outlives the branch below
paddle::Tensor *k_gate_weight = const_cast<paddle::Tensor*>(&k_block_means);
|
||||
|
||||
if (moba_use_mlp && attn_gate_weight) {
|
||||
k_gate_mlp = MobaMlpEinsum(
|
||||
k_input,
|
||||
attn_gate_weight.get(),
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_k,
|
||||
max_seq_len,
|
||||
kv_head_num
|
||||
)[0];
|
||||
k_gate_weight = &k_gate_mlp;
|
||||
}
|
||||
|
||||
auto qk_gate_weight = MobaQKGemm(
|
||||
q_input,
|
||||
*k_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_enc_len_this_time,
|
||||
max_enc_len_this_time + max_dec_len_this_time,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
false,
|
||||
moba_use_encoder_seq_limit
|
||||
)[0];
|
||||
|
||||
|
||||
auto qk_gate_topk_idx = QkSortEncoder(
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
cu_seq_q_pack,
|
||||
q_pack_tokens,
|
||||
max_enc_len_this_time,
|
||||
max_enc_len_this_time + max_dec_len_this_time,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
moba_encoder_top_k_left,
|
||||
moba_encoder_top_k_right,
|
||||
moba_use_mlp && !attn_gate_weight ? max_seq_len : moba_use_encoder_seq_limit)[0];
|
||||
|
||||
MobaEncoderAttn(
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
qk_gate_topk_idx,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
cu_seq_q_pack,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
out,
|
||||
max_enc_len_this_time,
|
||||
max_enc_len_this_time + max_dec_len_this_time,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_len
|
||||
);
|
||||
}
|
||||
|
||||
return {out};
|
||||
}
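// Summary (editorial comment, not part of the original file): MobaAttention
// runs up to two branches for a mixed batch, both built around the same
// block-sparse gating pattern: score cached KV blocks against the query,
// keep a top-k subset, then attend only over the selected blocks.
//   * Decode branch (max_dec_len_this_time > 0): MobaDecoderAttnWriteCacheKv
//     appends the new token's KV to the paged cache, MobaQKGemm scores the
//     per-block key means, QkSortDecoder keeps the top-k blocks, and
//     MobaDecoderAttn attends over the selection, writing into `out`.
//   * Prefill branch (max_enc_len_this_time > 0): FusedBlockMeanAndRope applies
//     rope and computes per-block key means, the cache is written and re-read
//     (MobaEncoderAttnWriteCacheKv, GetKVFromCache), the gate is optionally
//     replaced by an MLP projection (MobaMlpEinsum), and MobaQKGemm,
//     QkSortEncoder and MobaEncoderAttn perform the gated attention.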
|
||||
|
||||
|
||||
PD_BUILD_OP(moba_attention)
|
||||
.Inputs({
|
||||
"qkv",
|
||||
"q_input",
|
||||
"k_input",
|
||||
"v_input",
|
||||
"cu_seq_q",
|
||||
"cu_seq_k",
|
||||
"cu_seq_q_pack",
|
||||
"q_pack_tokens",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
"block_tables",
|
||||
"rope_sin_cos",
|
||||
"k_block_means",
|
||||
paddle::Optional("attn_gate_weight"),
|
||||
paddle::Optional("qkv_bias"),
|
||||
paddle::Optional("cache_k_quant_scale"),
|
||||
paddle::Optional("cache_v_quant_scale"),
|
||||
paddle::Optional("cache_k_dequant_scale"),
|
||||
paddle::Optional("cache_v_dequant_scale"),
|
||||
paddle::Optional("cache_k_zero_points"),
|
||||
paddle::Optional("cache_v_zero_points")})
|
||||
.Attrs({
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"head_dim: int",
|
||||
"max_seq_len: int",
|
||||
"max_enc_len_this_time: int",
|
||||
"max_dec_len_this_time: int",
|
||||
"moba_encoder_top_k_left: int",
|
||||
"moba_encoder_top_k_right: int",
|
||||
"moba_use_encoder_seq_limit: int",
|
||||
"moba_decoder_top_k_left: int",
|
||||
"moba_decoder_top_k_right: int",
|
||||
"moba_use_decoder_seq_limit: int",
|
||||
"moba_use_mlp: bool",
|
||||
"cache_quant_type_str: std::string"})
|
||||
.Outputs({
|
||||
"out",
|
||||
"q_input_out",
|
||||
"k_input_out",
|
||||
"v_input_out",
|
||||
"key_cache_out",
|
||||
"value_cache_out",
|
||||
"k_block_means_out"})
|
||||
.SetInplaceMap({{
|
||||
"q_input", "q_input_out"},
|
||||
{"k_input", "k_input_out"},
|
||||
{"v_input", "v_input_out"},
|
||||
{"key_cache", "key_cache_out"},
|
||||
{"value_cache", "value_cache_out"},
|
||||
{"k_block_means", "k_block_means_out"}})
|
||||
.SetKernelFn(PD_KERNEL(MobaAttention));
|
||||
204
custom_ops/gpu_ops/moba_attn/moba_attn.h
Normal file
@@ -0,0 +1,204 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
void MobaDecoderAttnWriteCacheKv(
|
||||
const paddle::Tensor& qkv_out,
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& rope_sin_cos,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const std::string &cache_quant_type_str);
|
||||
|
||||
void MobaEncoderAttnWriteCacheKv(
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_seq_q,
|
||||
const std::string &cache_quant_type_str);
|
||||
|
||||
void MobaDecoderAttn(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& qk_gate_topk_idx,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str);
|
||||
|
||||
|
||||
void FusedBlockMeanAndRope(
|
||||
const paddle::Tensor& qkv_out,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& rotary_embs,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str);
|
||||
|
||||
std::vector<paddle::Tensor> GetCurCuSeqLenk(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const int pack_size);
|
||||
|
||||
std::vector<paddle::Tensor> MobaQKGemm(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const bool is_split_kv,
|
||||
const int use_moba_seq_limit);
|
||||
|
||||
std::vector<paddle::Tensor> QkSortDecoder(
|
||||
const paddle::Tensor& qk_gate_weight,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit);
|
||||
|
||||
void GetKVFromCache(
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str);
|
||||
|
||||
|
||||
void MobaEncoderAttn(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& qk_gate_topk_idx,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& cu_seq_q_pack,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& out,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length);
|
||||
|
||||
std::vector<paddle::Tensor> QkSortEncoder(
|
||||
const paddle::Tensor& qk_gate_weight,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& cu_seq_q_pack,
|
||||
const paddle::Tensor& q_pack_tokens,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit);
|
||||
|
||||
std::vector<paddle::Tensor> MobaMlpEinsum(
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& attn_gate_weight,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const int max_seq_len,
|
||||
const int kv_head_num);
|
||||
748
custom_ops/gpu_ops/moba_attn/moba_attn_utils.hpp
Normal file
@@ -0,0 +1,748 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <cuda_fp16.h>
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/int_tuple.hpp"
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename T>
|
||||
struct PackedHalf;
|
||||
|
||||
template<>
|
||||
struct PackedHalf<cutlass::half_t> {
|
||||
using Type = __half2;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct PackedHalf<cutlass::bfloat16_t> {
|
||||
using Type = nv_bfloat162;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct PackedHalf<phi::dtype::float16> {
|
||||
using Type = __half2;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct PackedHalf<phi::dtype::bfloat16> {
|
||||
using Type = nv_bfloat162;
|
||||
};
|
||||
|
||||
|
||||
template<typename T>
|
||||
struct HalfSub;
|
||||
|
||||
template<>
|
||||
struct HalfSub<cutlass::half_t> {
|
||||
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num));
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct HalfSub<cutlass::bfloat16_t> {
|
||||
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
|
||||
*reinterpret_cast<nv_bfloat162*>(result_ptr) -= *reinterpret_cast<const nv_bfloat162*>(&magic_num);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct HalfMul;
|
||||
|
||||
template<>
|
||||
struct HalfMul<cutlass::half_t> {
|
||||
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
|
||||
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(*result_ptr) : "r"(*result_ptr), "r"(magic_num));
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct HalfMul<cutlass::bfloat16_t> {
|
||||
inline __device__ void operator()(uint32_t* result_ptr, const uint32_t magic_num) {
|
||||
*reinterpret_cast<nv_bfloat162*>(result_ptr) *= *reinterpret_cast<const nv_bfloat162*>(&magic_num);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<typename T>
|
||||
struct HalfMax;
|
||||
template<>
|
||||
struct HalfMax<cutlass::half_t> {
|
||||
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
|
||||
__half2 res;
|
||||
asm volatile("max.f16x2 %0, %1, %2;\n" :
|
||||
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&x)),
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&y)));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct HalfMax<cutlass::bfloat16_t> {
|
||||
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
|
||||
nv_bfloat162 res;
|
||||
asm volatile("max.bf16x2 %0, %1, %2;\n" :
|
||||
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&x)),
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&y)));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<typename T>
|
||||
struct HalfMin;
|
||||
template<>
|
||||
struct HalfMin<cutlass::half_t> {
|
||||
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
|
||||
__half2 res;
|
||||
asm volatile("min.f16x2 %0, %1, %2;\n" :
|
||||
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&x)),
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&y)));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct HalfMin<cutlass::bfloat16_t> {
|
||||
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
|
||||
nv_bfloat162 res;
|
||||
asm volatile("min.bf16x2 %0, %1, %2;\n" :
|
||||
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&x)),
|
||||
"r"(*reinterpret_cast<const uint32_t*>(&y)));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct MinOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MinOp<float> {
|
||||
__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
|
||||
};
|
||||
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<typename T, bool Is_K>
|
||||
inline __device__ static void convert_c8_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) {
|
||||
uint32_t* half_result_ptr = reinterpret_cast<uint32_t*>(dst);
|
||||
if constexpr (std::is_same_v<T, cutlass::bfloat16_t>) {
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
float fp32_intermediates[4];
|
||||
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
fp32_intermediates_casted[0] = __byte_perm(*src, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(*src, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[2] = __byte_perm(*src, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[3] = __byte_perm(*src, fp32_base, 0x7653);
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 4; ++ii) {
|
||||
fp32_intermediates[ii] -= 8388608.f;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
half_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
|
||||
}
|
||||
} else {
|
||||
static constexpr uint32_t head_for_fp16 = 0x64006400;
|
||||
half_result_ptr[0] = __byte_perm(*src, head_for_fp16, 0x7150);
|
||||
half_result_ptr[1] = __byte_perm(*src, head_for_fp16, 0x7352);
|
||||
}
|
||||
|
||||
using pack_half = typename PackedHalf<T>::Type;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++){
|
||||
if constexpr (Is_K) {
|
||||
HalfSub<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_zp + i * 2));
|
||||
HalfMul<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_scale + i * 2));
|
||||
} else {
|
||||
pack_half bias;
|
||||
pack_half scale;
|
||||
bias.x = cache_zp[0];
|
||||
bias.y = cache_zp[0];
|
||||
scale.x = cache_scale[0];
|
||||
scale.y = cache_scale[0];
|
||||
HalfSub<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(&bias));
|
||||
HalfMul<T>()(half_result_ptr + i, *reinterpret_cast<const uint32_t*>(&scale));
|
||||
}
|
||||
}
|
||||
}
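// Editorial sketch (not part of the original file): the fp16 branch above uses
// the classic magic-number trick. __byte_perm splices each uint8 payload into
// the low byte of the half pattern 0x6400 (== 1024.0f), so the assembled half
// equals 1024 + byte; the zero-point subtraction that follows is then expected
// to account for this constant 1024 offset. The identity, spelled out on the
// host (assumes <cmath> for std::ldexp):
inline float fp16_bits_to_float_sketch(uint16_t h) {
    // decode a normal (non-denormal) fp16 bit pattern; enough for this example
    const int exp = (h >> 10) & 0x1f;
    const int man = h & 0x3ff;
    return std::ldexp(1.0f + man / 1024.0f, exp - 15);
}
// For any byte q in [0, 255]: fp16_bits_to_float_sketch(0x6400 | q) == 1024.0f + q.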
|
||||
|
||||
template<typename T, bool Is_K>
|
||||
inline __device__ static void convert_c4_2_half(uint32_t *src, T *dst, const T *cache_scale, const T* cache_zp) {
|
||||
using pack_half = typename PackedHalf<T>::Type;
|
||||
static constexpr uint32_t MASK = 0x0f0f0f0f;
|
||||
static constexpr uint32_t head_for_fp16 = std::is_same_v<T, cutlass::bfloat16_t> ? 0x43004300 : 0x64006400;
|
||||
static constexpr uint32_t mask_for_c42fp16_one = 0x7253;
|
||||
static constexpr uint32_t mask_for_c42fp16_two = 0x7051;
|
||||
uint32_t* result_ptr = reinterpret_cast<uint32_t*>(dst);
|
||||
uint32_t source = *reinterpret_cast<uint32_t const*>(src);
|
||||
// source = {e0 e4 e1 e5 e2 e6 e3 e7}
|
||||
uint32_t bottom_i4s = source & MASK;
|
||||
// bottom_i4s = {0 e4 0 e5 0 e6 0 e7}
|
||||
uint32_t top_i4s = (source >> 4) & MASK;
|
||||
// top_i4s = {0 e0 0 e1 0 e2 0 e3}
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[0]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one));
|
||||
// result_ptr[0] = {e0 e1}
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[1]) : "r"(top_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two));
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[2]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_one));
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_ptr[3]) : "r"(bottom_i4s), "n"(head_for_fp16), "n"(mask_for_c42fp16_two));
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
if constexpr (Is_K) {
|
||||
const int ith_col = i % 2 * 2;
|
||||
HalfSub<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_zp + ith_col));
|
||||
HalfMul<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(cache_scale + ith_col));
|
||||
} else {
|
||||
const int ith_col = i / 2;
|
||||
pack_half bias;
|
||||
pack_half scale;
|
||||
bias.x = cache_zp[ith_col];
|
||||
bias.y = cache_zp[ith_col];
|
||||
scale.x = cache_scale[ith_col];
|
||||
scale.y = cache_scale[ith_col];
|
||||
HalfSub<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(&bias));
|
||||
HalfMul<T>()(result_ptr + i, *reinterpret_cast<const uint32_t*>(&scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename CacheKV_traits, typename T, int kHeadDim, int kDataNumPer2Byte, bool A_in_regs=false, typename Tensor0, typename Tensor1,
|
||||
typename Tensor2, typename Tensor3, typename Tensor4,
|
||||
typename TiledMma, typename ThrCopy0, typename TiledCopy0>
|
||||
inline __device__ void gemm_qk_quant(
|
||||
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB,
|
||||
Tensor4 const& sB, TiledMma tiled_mma,
|
||||
ThrCopy0 smem_thr_copy_A,
|
||||
TiledCopy0 smem_tiled_copy_A,
|
||||
const int32_t tidx,
|
||||
const T * cache_scale, const T * cache_zp) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
|
||||
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
|
||||
if (!A_in_regs) {
|
||||
copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
|
||||
}
|
||||
uint32_t *sBdata = reinterpret_cast<uint32_t *>(sB.data().get()) + tidx * (kDataNumPer2Byte / 4);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
if (!A_in_regs) {
|
||||
copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
|
||||
}
|
||||
}
|
||||
if constexpr (kDataNumPer2Byte == 4) {
|
||||
convert_c4_2_half<T, true>(sBdata + i * kHeadDim, tCrB.data(), cache_scale + i * 4, cache_zp + i * 4);
|
||||
} else {
|
||||
convert_c8_2_half<T, true>(sBdata + i * (kHeadDim * 2), tCrB.data(), cache_scale + i * 4, cache_zp + i * 4);
|
||||
convert_c8_2_half<T, true>(sBdata + i * (kHeadDim * 2) + 1, tCrB.data() + 4, cache_scale + i * 4, cache_zp + i * 4);
|
||||
}
|
||||
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename CacheKV_traits, typename T, int kHeadDim, int kDataNumPer2Byte, bool A_in_regs=false, typename Tensor0, typename Tensor1,
|
||||
typename Tensor2, typename Tensor3, typename Tensor4,
|
||||
typename TiledMma, typename ThrCopy0, typename TiledCopy0>
|
||||
inline __device__ void gemm_value_quant(
|
||||
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCsA, Tensor3 &tCrB,
|
||||
Tensor4 const& sB, TiledMma tiled_mma,
|
||||
ThrCopy0 smem_thr_copy_A,
|
||||
TiledCopy0 smem_tiled_copy_A,
|
||||
int32_t tidx,
|
||||
const T * cache_scale, const T * cache_zp) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
|
||||
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
|
||||
if (!A_in_regs) {
|
||||
copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
|
||||
}
|
||||
uint32_t *sBdata = reinterpret_cast<uint32_t *>(sB.data().get()) + tidx * (2 * kDataNumPer2Byte / 4);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
const int cur_idx = i * kHeadDim * (2 * kDataNumPer2Byte / 4);
|
||||
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
if (!A_in_regs) {
|
||||
copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
|
||||
}
|
||||
}
|
||||
if constexpr (kDataNumPer2Byte == 4) {
|
||||
convert_c4_2_half<T, false>(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp);
|
||||
convert_c4_2_half<T, false>(sBdata + cur_idx + 1, tCrB.data() + 8, cache_scale + 2, cache_zp + 2);
|
||||
} else {
|
||||
convert_c8_2_half<T, false>(sBdata + cur_idx, tCrB.data(), cache_scale, cache_zp);
|
||||
convert_c8_2_half<T, false>(sBdata + cur_idx + 1, tCrB.data() + 4, cache_scale + 1, cache_zp + 1);
|
||||
convert_c8_2_half<T, false>(sBdata + cur_idx + 2, tCrB.data() + 8, cache_scale + 2, cache_zp + 2);
|
||||
convert_c8_2_half<T, false>(sBdata + cur_idx + 3, tCrB.data() + 12, cache_scale + 3, cache_zp + 3);
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB, acc);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kMiLen, typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &scores, const uint32_t warp_id, const uint32_t col, const uint32_t remain_seq_len) {
|
||||
const int cols = size<1>(scores) / 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < cols; ++ni) {
|
||||
const int col_index = warp_id * 8 + ni * 32 + col * 2;
|
||||
if (col_index >= remain_seq_len) {
|
||||
scores(mi, ni * 2) = -INFINITY;
|
||||
}
|
||||
if (col_index + 1 >= remain_seq_len) {
|
||||
scores(mi, ni * 2 + 1) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
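// Hypothetical usage sketch (not from the original file): reduce one float per
// lane to the maximum across a group of 4 neighbouring lanes, the same pattern
// reduce_max below applies to each row of attention scores.
__device__ inline float quad_max_sketch(float x) {
    MaxOp<float> max_op;
    return Allreduce<4>::run(x, max_op);  // butterfly shuffles with offsets 2, then 1
}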
|
||||
|
||||
template<int kMiLen, typename Engine0, typename Layout0, typename T>
|
||||
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, T *scores_max){
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
MaxOp<T> max_op;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ni++) {
|
||||
scores_max[mi] = max_op(scores_max[mi], tensor(mi, ni));
|
||||
}
|
||||
scores_max[mi] = Allreduce<4>::run(scores_max[mi], max_op);
|
||||
}
|
||||
}
|
||||
|
||||
template <int kMiLen, typename Engine0, typename Layout0, typename T>
|
||||
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, T const *max, T *sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
const float max_scaled = max[mi] * scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = expf(tensor(mi, ni) * scale - max_scaled);
|
||||
sum[mi] += tensor(mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename paddle_type>
|
||||
struct cuteType;
|
||||
|
||||
template <>
|
||||
struct cuteType<phi::dtype::float16> {
|
||||
using type = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct cuteType<phi::dtype::bfloat16> {
|
||||
using type = cutlass::bfloat16_t;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
__forceinline__ __device__ auto float_2_half2(const float x) {
|
||||
if constexpr (std::is_same<T, cutlass::half_t>::value) {
|
||||
return __float2half2_rn(x);
|
||||
} else {
|
||||
return __float2bfloat162_rn(x);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
struct uint16 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
uint4 s;
|
||||
uint4 t;
|
||||
};
|
||||
|
||||
|
||||
struct uint8 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
};
|
||||
|
||||
template<int BYTES>
|
||||
struct BytesToType {};
|
||||
|
||||
template<>
|
||||
struct BytesToType<64> {
|
||||
using Type = uint16;
|
||||
static_assert(sizeof(Type) == 64);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<32> {
|
||||
using Type = uint8;
|
||||
static_assert(sizeof(Type) == 32);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
template<typename Elt_type, uint32_t NUM_ELT>
|
||||
struct Vec {
|
||||
|
||||
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
|
||||
|
||||
using Vec_type = typename BytesToType<BYTES>::Type;
|
||||
|
||||
using Alias_type = union {
|
||||
Vec_type vec;
|
||||
Elt_type elt[NUM_ELT];
|
||||
};
|
||||
|
||||
Alias_type data;
|
||||
|
||||
inline __device__ Vec() {}
|
||||
|
||||
template<typename S>
|
||||
inline __device__ void to(Vec<S, NUM_ELT> &other) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
other.data.elt[it] = S(this->data.elt[it]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ void assign(const Op &op) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
this->data.elt[it] = op(it);
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void load_from(const void *base_ptr) {
|
||||
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
|
||||
}
|
||||
|
||||
|
||||
inline __device__ void store_to(void *base_ptr) {
|
||||
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
|
||||
}
|
||||
|
||||
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
|
||||
static_assert(NUM_ELT % 2 == 0);
|
||||
using type = typename PackedHalf<Elt_type>::Type;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT / 2; it++) {
|
||||
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
|
||||
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void set_zero() {
|
||||
constexpr int size = sizeof(Vec_type) / sizeof(int);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size; ++i) {
|
||||
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
|
||||
static_assert(NUM_ELT % 2 == 0);
|
||||
using type = typename PackedHalf<Elt_type>::Type;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT / 2; it++) {
|
||||
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
|
||||
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
|
||||
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
|
||||
}
|
||||
}
|
||||
};
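// Hypothetical usage sketch (not from the original file): move 8 half elements
// (16 bytes) per thread with one vectorized load/store. BYTES == 16 selects
// uint4 as the backing type, so both pointers are assumed 16-byte aligned.
__device__ inline void copy8_halves_sketch(const phi::dtype::float16* src,
                                            phi::dtype::float16* dst) {
    Vec<phi::dtype::float16, 8> v;
    v.load_from(src);
    v.store_to(dst);
}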
|
||||
|
||||
template<typename T, int PackSize>
|
||||
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
|
||||
static_assert(PackSize % 2 == 0);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < PackSize / 2; i++) {
|
||||
const float cos_inv_freq = cos.data.elt[i];
|
||||
const float sin_inv_freq = sin.data.elt[i];
|
||||
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
|
||||
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
|
||||
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
|
||||
vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
|
||||
}
|
||||
}
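// Editorial sketch (not part of the original file): the loop above applies the
// standard rotary embedding, rotating each (even, odd) element pair by the
// angle whose cos/sin were precomputed for that pair's inverse frequency:
__device__ inline void rotate_pair_sketch(float cos_f, float sin_f,
                                          float& v1, float& v2) {
    const float r1 = cos_f * v1 - sin_f * v2;
    const float r2 = sin_f * v1 + cos_f * v2;
    v1 = r1;
    v2 = r2;
}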
|
||||
|
||||
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
|
||||
__forceinline__ __device__ void copy(
|
||||
TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &identity_MN,
|
||||
const int max_MN = 0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
|
||||
typename Tensor2, typename Tensor3, typename Tensor4,
|
||||
typename TiledMma, typename ThrCopy0, typename ThrCopy1,
|
||||
typename TiledCopy0, typename TiledCopy1>
|
||||
inline __device__ void gemm(
|
||||
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
|
||||
Tensor4 const& tCsB, TiledMma tiled_mma,
|
||||
ThrCopy0 &smem_thr_copy_A, ThrCopy1 &smem_thr_copy_B,
|
||||
TiledCopy0 &smem_tiled_copy_A, TiledCopy1 &smem_tiled_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));
|
||||
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));
|
||||
|
||||
if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
|
||||
if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
if (!A_in_regs) { copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
|
||||
if (!B_in_regs) { copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template<typename T, typename ReductionOp, int block_size>
|
||||
__inline__ __device__ T BlockAllReduce(T val) {
|
||||
typedef cub::BlockReduce<T, block_size> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ T result_broadcast;
|
||||
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
|
||||
if (threadIdx.x == 0) { result_broadcast = result; }
|
||||
__syncthreads();
|
||||
return result_broadcast;
|
||||
}
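// Hypothetical usage sketch (not from the original file): block-wide maximum
// over one float per thread, combining the MaxOp functor with the cub-backed
// BlockAllReduce helper above; the result is broadcast to every thread.
template <int kThreads>
__device__ inline float block_max_sketch(float v) {
    return BlockAllReduce<float, MaxOp<float>, kThreads>(v);
}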
|
||||
|
||||
template<typename MMA_traits, typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
|
||||
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
|
||||
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
|
||||
} else { // SM80
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
if constexpr (mma_shape_K == 8) {
|
||||
return acc_layout;
|
||||
} else {
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
|
||||
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
|
||||
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
||||
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (arrive) {
|
||||
warpgroup_arrive();
|
||||
}
|
||||
if constexpr (zero_init) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
}
|
||||
if constexpr (commit) {
|
||||
warpgroup_commit_batch();
|
||||
}
|
||||
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||
}
|
||||
|
||||
|
||||
template<typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = acc_layout;
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
|
||||
} else { // SM80
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename ReductionOp, int thread_group_width = 32>
|
||||
__inline__ __device__ T WarpAllReduce(T val) {
|
||||
ReductionOp op;
|
||||
#pragma unroll
|
||||
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
|
||||
val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
|
||||
template <int kPackSize, int knthreads>
|
||||
__device__ inline int get_data_count(const float * src, const float limit_value) {
|
||||
int count = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
if (src[i] >= limit_value) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
count = BlockAllReduce<int, SumOp<int>, knthreads>(count);
|
||||
return count;
|
||||
}
|
||||
@@ -0,0 +1,802 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_decoder_attn_kernel.h"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
|
||||
template<bool Is_first, int kMiLen, typename Tensor0, typename Tensor1, typename T>
|
||||
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &acc_o, const T *scores_max, const T *scores_max_prev, T * scores_sum, const float softmax_scale) {
|
||||
if (Is_first) {
|
||||
scale_apply_exp2<kMiLen>(scores, scores_max, scores_sum, softmax_scale);
|
||||
} else {
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
const float scores_scale = expf((scores_max_prev[mi] - scores_max[mi]) * softmax_scale);
|
||||
scores_sum[mi] *= scores_scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale;
|
||||
}
|
||||
}
|
||||
scale_apply_exp2<kMiLen>(scores, scores_max, scores_sum, softmax_scale);
|
||||
}
|
||||
};
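// Editorial sketch (not part of the original file): softmax_rescale_o is the
// streaming ("online") softmax update. The same recurrence for a single query
// row consuming one (score, value) pair per step:
struct OnlineSoftmaxSketch {
    float m = -INFINITY;   // running max of raw scores
    float s = 0.f;         // running sum of exp(scale * (score - m))
    float o = 0.f;         // running, un-normalized output accumulator
    __device__ void update(float score, float value, float scale) {
        const float m_new = fmaxf(m, score);
        const float corr  = expf((m - m_new) * scale);   // rescale previous state
        const float p     = expf((score - m_new) * scale);
        s = s * corr + p;
        o = o * corr + p * value;
        m = m_new;
    }
    // the normalized attention output after the last step is o / s
};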
|
||||
|
||||
template<typename Kernel_traits, typename ParamType>
|
||||
__global__ __launch_bounds__(Kernel_traits::kNThreads) void moba_decoder_attention_kernel(ParamType params) {
|
||||
using cuteType = typename Kernel_traits::cuteType;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using CacheKV_traits = typename Kernel_traits::CacheKV_traits;
|
||||
constexpr int32_t kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int32_t kHeadDimKV = Kernel_traits::kHeadDimKV;
|
||||
constexpr int32_t kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int32_t kBlockSize = Kernel_traits::kBlockSize;
|
||||
constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize;
|
||||
constexpr int32_t kNWarps = Kernel_traits::kNWarps;
|
||||
constexpr int32_t kTileN = Kernel_traits::kTileN;
|
||||
constexpr int32_t kBlockN = kTileN * kBlockSize;
|
||||
constexpr int32_t kDataBits = Kernel_traits::kDataBits;
|
||||
constexpr int32_t kMiLen = (kGqaGroupSize + 7) / 8;
|
||||
|
||||
const int32_t bi = blockIdx.y;
|
||||
const int32_t tidx = threadIdx.x;
|
||||
const int32_t partition_idx = blockIdx.x;
|
||||
const int32_t kv_head_idx = blockIdx.z;
|
||||
const int32_t q_head_idx = kv_head_idx * kGqaGroupSize;
|
||||
|
||||
const int32_t seq_len = params.seq_lens_decoder[bi] == 0 ? 0 : params.seq_lens_decoder[bi] + 1;
|
||||
|
||||
const int32_t head_num = params.head_num;
|
||||
const int32_t kv_head_num = params.kv_head_num;
|
||||
|
||||
const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN;
|
||||
|
||||
if (seq_len == 0 || partition_idx >= partition_num) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (seq_len >= params.use_moba_seq_limit && params.qk_gate_topk_idx_ptr[(bi * kv_head_num + kv_head_idx) * Kernel_traits::kMaxN + partition_idx] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const int q_bias_offset = q_head_idx * kHeadDim;
|
||||
|
||||
cuteType * q_input = reinterpret_cast<cuteType *>(params.q_input) + params.cu_seq_q[bi] * head_num * kHeadDim;
|
||||
|
||||
Tensor gQ = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<const cuteType *>(q_input) + q_bias_offset),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
const int32_t block_idx = partition_idx * kTileN;
|
||||
const int* block_table = params.block_table + bi * params.max_num_blocks_per_seq + block_idx;
|
||||
const int32_t physical_block_number = block_table[0];
|
||||
|
||||
const int32_t cache_offset = (physical_block_number * kv_head_num + kv_head_idx) * kBlockSize * kHeadDimKV;
|
||||
|
||||
Tensor gK = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<const cuteType *>(params.cache_k) + cache_offset),
|
||||
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{},
|
||||
Stride<Int<kHeadDimKV>, _1>{});
|
||||
|
||||
Tensor gV = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<const cuteType *>(params.cache_v) + cache_offset),
|
||||
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{},
|
||||
Stride<Int<kHeadDimKV>, _1>{});
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
Tensor sQ = make_tensor(
|
||||
make_smem_ptr(reinterpret_cast<cuteType *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutQ{});
|
||||
Tensor sQK = make_tensor(
|
||||
sQ.data() + size(sQ),
|
||||
typename Kernel_traits::SmemLayoutQK{});
|
||||
|
||||
Tensor sK = make_tensor(sQK.data() + size(sQK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
||||
__shared__ ElementAccum scores_warp[kNWarps][kMiLen * kBlockM];
|
||||
|
||||
auto gmem_tiled_copy_Q = typename Kernel_traits::GmemTiledCopyQ{};
|
||||
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
|
||||
auto gmem_tiled_copy_KV = typename Kernel_traits::GmemTiledCopyKV{};
|
||||
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
|
||||
|
||||
Tensor tKgK = gmem_thr_copy_KV.partition_S(gK);
|
||||
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_KV.partition_S(gV);
|
||||
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
|
||||
|
||||
Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim));
|
||||
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
|
||||
|
||||
Tensor cKV = make_identity_tensor(make_shape(kBlockSize, kHeadDim));
|
||||
Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
using SmemCopyAtom = typename Kernel_traits::SmemCopyAtom;
|
||||
auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
|
||||
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
|
||||
|
||||
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
|
||||
|
||||
Tensor tSsQK = smem_thr_copy_Q.partition_S(sQK);
|
||||
Tensor tSrQK = thr_mma.partition_fragment_A(sQK);
|
||||
|
||||
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK);
|
||||
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);
|
||||
|
||||
copy<false>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, kGqaGroupSize);
|
||||
|
||||
|
||||
cute::cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
|
||||
const int32_t remain_seq_len = seq_len - partition_idx * kTileN * kBlockSize;
|
||||
|
||||
copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV);
|
||||
|
||||
cute::cp_async_fence();
|
||||
|
||||
const int32_t warp_id = tidx / 32;
|
||||
const int32_t lane_id = tidx % 32;
|
||||
const int32_t row = lane_id / 4;
|
||||
const int32_t col = lane_id % 4;
|
||||
const int row_idx = tidx / 4;
|
||||
|
||||
using scale_k_vec = Vec<cuteType, 32>;
|
||||
using scale_v_vec = Vec<cuteType, 4>;
|
||||
|
||||
scale_k_vec scale_k;
|
||||
scale_k_vec zp_k;
|
||||
scale_v_vec scale_v;
|
||||
scale_v_vec zp_v;
|
||||
if constexpr (kDataBits == 4) {
|
||||
scale_k = *reinterpret_cast<const scale_k_vec*>(params.cache_k_dequant_scale + kv_head_idx * kHeadDim + col * 32);
|
||||
zp_k = *reinterpret_cast<const scale_k_vec*>(params.cache_k_zp + kv_head_idx * kHeadDim + col * 32);
|
||||
scale_v = *reinterpret_cast<const scale_v_vec*>(params.cache_v_dequant_scale + kv_head_idx * kHeadDim + row_idx * 4);
|
||||
zp_v = *reinterpret_cast<const scale_v_vec*>(params.cache_v_zp + kv_head_idx * kHeadDim + row_idx * 4);
|
||||
}
|
||||
|
||||
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
|
||||
clear(acc_o);
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockSize>>{});
|
||||
|
||||
ElementAccum scores_max[kMiLen];
|
||||
ElementAccum scores_max_prev[kMiLen];
|
||||
ElementAccum scores_sum[kMiLen];
|
||||
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
scores_max[mi] = -INFINITY;
|
||||
scores_sum[mi] = 0;
|
||||
}
|
||||
|
||||
const int cache_offset_step = kv_head_num * kBlockSize * kHeadDimKV;
|
||||
|
||||
#pragma unroll
|
||||
for (int n = 0; n < kTileN; ++n) {
|
||||
const int cur_remain_seq_len = remain_seq_len - n * kBlockSize;
|
||||
|
||||
if (cur_remain_seq_len <= 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
clear(acc_s);
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
if (n > 0) {
|
||||
tVgV.data() = tVgV.data() + (block_table[n] - block_table[n - 1]) * cache_offset_step;
|
||||
}
|
||||
|
||||
copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV);
|
||||
|
||||
cute::cp_async_fence();
|
||||
|
||||
if constexpr (kDataBits == 16) {
|
||||
if (n == 0) {
|
||||
gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
|
||||
} else {
|
||||
gemm<true>(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
|
||||
}
|
||||
} else {
|
||||
Tensor tSrKQuant = make_tensor<cuteType>(
|
||||
Layout<
|
||||
Shape<Shape<_2, _2>, Int<kBlockSize / 32>>,
|
||||
Stride<Shape<_1, _2>, _4>>{});
|
||||
if (n == 0) {
|
||||
gemm_qk_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits>(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt);
|
||||
} else {
|
||||
gemm_qk_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits, true>(acc_s, tSrQ, tSsQ, tSrKQuant, sK, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_k.data.elt, zp_k.data.elt);
|
||||
}
|
||||
}
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
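// In the last partition, mask out score columns beyond the true sequence length.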
if (partition_idx == partition_num - 1 && cur_remain_seq_len < kBlockSize) {
|
||||
apply_mask<kMiLen>(scores, warp_id, col, cur_remain_seq_len);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
scores_max_prev[mi] = scores_max[mi];
|
||||
}
|
||||
|
||||
reduce_max<kMiLen>(scores, scores_max);
|
||||
|
||||
if (col == 0) {
|
||||
scores_warp[warp_id][row] = scores_max[0];
|
||||
if constexpr (kMiLen > 1) {
|
||||
scores_warp[warp_id][row + 8] = scores_max[1];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
MaxOp<ElementAccum> max_op;
|
||||
|
||||
if (tidx < kGqaGroupSize) {
|
||||
float cur_max = scores_warp[0][tidx];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 1; i < kNWarps; ++i) {
|
||||
cur_max = max_op(scores_warp[i][tidx], cur_max);
|
||||
}
|
||||
scores_warp[0][tidx] = cur_max;
|
||||
}
|
||||
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
if (cur_remain_seq_len > kBlockSize && n < kTileN - 1) {
|
||||
tKgK.data() = tKgK.data() + (block_table[n + 1] - block_table[n]) * cache_offset_step;
|
||||
copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV);
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
scores_max[mi] = scores_warp[0][row + mi * 8];
|
||||
}
|
||||
|
||||
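// Online softmax: rescale the accumulated output and running sums with the new row
// maximum; the first block initializes the statistics.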
if (n == 0) {
|
||||
softmax_rescale_o<true, kMiLen>(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh);
|
||||
} else {
|
||||
softmax_rescale_o<false, kMiLen>(scores, acc_o, scores_max, scores_max_prev, scores_sum, params.inv_sqrt_dh);
|
||||
}
|
||||
|
||||
Tensor rS = convert_type<cuteType>(acc_s);
|
||||
|
||||
Tensor trQK = smem_thr_copy_O.retile_S(rS);
|
||||
Tensor tsQK = smem_thr_copy_O.partition_D(sQK);
|
||||
cute::copy(smem_tiled_copy_O, trQK, tsQK);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
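// Second GEMM: multiply the softmax-weighted scores (staged in shared memory as sQK) by V,
// again with a dedicated dequantization path for quantized caches.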
if constexpr (kDataBits == 16) {
|
||||
gemm(acc_o, tSrQK, tOrVt, tSsQK, tOsVt, tiled_mma, smem_thr_copy_Q, smem_thr_copy_V, smem_tiled_copy_Q, smem_tiled_copy_V);
|
||||
} else {
|
||||
Tensor tSrVQuant = make_tensor<cuteType>(
|
||||
Layout<
|
||||
Shape<_4, Shape<_2, _2>>,
|
||||
Stride<_1, Shape<_4, _8>>>{});
|
||||
gemm_value_quant<CacheKV_traits, cuteType, kHeadDim, kDataBits>(acc_o, tSrQK, tSsQK, tSrVQuant, sV, tiled_mma, smem_thr_copy_Q, smem_tiled_copy_Q, tidx, scale_v.data.elt, zp_v.data.elt);
|
||||
}
|
||||
}
|
||||
|
||||
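// Persist this partition's row maxima and exp-sums so the merge kernel can renormalize
// the partial outputs.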
const uint32_t pack_max_partition_num = (params.max_num_partitions + 3) / 4 * 4;
|
||||
uint32_t max_sum_offset = bi * pack_max_partition_num * head_num + (tidx + q_head_idx) * pack_max_partition_num + partition_idx;
|
||||
|
||||
if (tidx < kGqaGroupSize) {
|
||||
params.maxs[max_sum_offset] = scores_warp[0][tidx] * params.inv_sqrt_dh;
|
||||
}
|
||||
|
||||
SumOp<ElementAccum> sum_op;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < kMiLen; ++mi) {
|
||||
scores_sum[mi] = Allreduce<4>::run(scores_sum[mi], sum_op);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (col == 0) {
|
||||
scores_warp[warp_id][row] = scores_sum[0];
|
||||
if constexpr (kMiLen > 1) {
|
||||
scores_warp[warp_id][row + 8] = scores_sum[1];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Tensor rO = convert_type<cuteType>(acc_o);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(rO);
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sQ);
|
||||
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tidx < kGqaGroupSize) {
|
||||
float cur_sum = scores_warp[0][tidx];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 1; i < kNWarps; ++i) {
|
||||
cur_sum = sum_op(scores_warp[i][tidx], cur_sum);
|
||||
}
|
||||
scores_warp[0][tidx] = cur_sum;
|
||||
}
|
||||
|
||||
Tensor gO = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<cuteType *>(params.partition_attn_out) + ((bi * params.max_num_partitions + partition_idx) * head_num + q_head_idx)* kHeadDim),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
auto gmem_tiled_copy_O = typename Kernel_traits::GmemTiledCopyO{};
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sQ);
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
constexpr int32_t copy_size = kGqaGroupSize * 16;
|
||||
__syncthreads();
|
||||
|
||||
if (tidx < copy_size) {
|
||||
cute::copy(gmem_tiled_copy_O, tOsO(_, 0, _), tOgO(_, 0, _));
|
||||
}
|
||||
|
||||
if constexpr (kMiLen > 1) {
|
||||
if (tidx < copy_size - 128) {
|
||||
cute::copy(gmem_tiled_copy_O, tOsO(_, 1, _), tOgO(_, 1, _));
|
||||
}
|
||||
}
|
||||
|
||||
if (tidx < kGqaGroupSize) {
|
||||
params.sums[max_sum_offset] = scores_warp[0][tidx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
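// Reduces the per-partition maxima and exp-sums for one (batch, head): computes the global
// max, rescales each partition's exp-sum by exp(max - global_max), stores the scales in
// shared memory, and returns 1 / global_exp_sum.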
template<typename Kernel_traits, typename ParamType>
|
||||
inline __device__ float caluate_logit_scale(const int partition_num, const int pack_max_partition_num, ParamType ¶ms, char * shared_mem, const int seq_len, const int *qk_gate_topk_idx_ptr) {
|
||||
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
|
||||
constexpr int32_t kNReduceThreads = Kernel_traits::kNReduceThreads;
|
||||
const int32_t bi = blockIdx.z;
|
||||
const int32_t tidx = threadIdx.x;
|
||||
const int32_t head_idx = blockIdx.y;
|
||||
const int32_t head_num = params.head_num;
|
||||
|
||||
using float_vec = Vec<float, kNFloatPacksize>;
|
||||
const int32_t offset = bi * head_num * pack_max_partition_num + head_idx * pack_max_partition_num;
|
||||
|
||||
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
|
||||
const float* max_logits_ptr = params.maxs + offset;
|
||||
float global_max_logit = -FLT_MAX;
|
||||
|
||||
int32_t idx = tidx * kNFloatPacksize;
|
||||
#pragma unroll
|
||||
for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) {
|
||||
float_vec cur_max = *reinterpret_cast<const float_vec*>(max_logits_ptr + idx);
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kNFloatPacksize; ++j) {
|
||||
if (seq_len >= params.use_moba_seq_limit) {
|
||||
if (qk_gate_topk_idx_ptr[idx + j] != 0) {
|
||||
global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]);
|
||||
}
|
||||
} else {
|
||||
global_max_logit = fmaxf(global_max_logit, cur_max.data.elt[j]);
|
||||
}
|
||||
}
|
||||
cur_max.store_to(shared_max_logits + idx);
|
||||
}
|
||||
|
||||
const int32_t packed_data_num = partition_num / kNFloatPacksize * kNFloatPacksize;
|
||||
|
||||
idx = packed_data_num + tidx;
|
||||
#pragma unroll
|
||||
for (; idx < partition_num; idx += kNReduceThreads) {
|
||||
if (seq_len >= params.use_moba_seq_limit) {
|
||||
if (qk_gate_topk_idx_ptr[idx] != 0) {
|
||||
float cur_max = max_logits_ptr[idx];
|
||||
global_max_logit = fmaxf(global_max_logit, cur_max);
|
||||
shared_max_logits[idx] = cur_max;
|
||||
}
|
||||
} else {
|
||||
float cur_max = max_logits_ptr[idx];
|
||||
global_max_logit = fmaxf(global_max_logit, cur_max);
|
||||
shared_max_logits[idx] = cur_max;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
global_max_logit = BlockAllReduce<float, MaxOp<float>, kNReduceThreads>(global_max_logit);
|
||||
|
||||
float* share_sum_scale = reinterpret_cast<float*>(shared_mem + sizeof(float) * pack_max_partition_num);
|
||||
const float* exp_sums_ptr = params.sums + offset;
|
||||
float global_exp_sum = 0.0f;
|
||||
|
||||
idx = tidx * kNFloatPacksize;
|
||||
#pragma unroll
|
||||
for (; idx <= partition_num - kNFloatPacksize; idx += kNReduceThreads * kNFloatPacksize) {
|
||||
float_vec share_max = *reinterpret_cast<const float_vec*>(shared_max_logits + idx);
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kNFloatPacksize; ++j) {
|
||||
if (seq_len >= params.use_moba_seq_limit) {
|
||||
if (qk_gate_topk_idx_ptr[idx + j] != 0) {
|
||||
float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit);
|
||||
float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max;
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
share_max.data.elt[j] = exp_sub_max;
|
||||
}
|
||||
} else {
|
||||
float exp_sub_max = expf(share_max.data.elt[j] - global_max_logit);
|
||||
float rescaled_exp_sum = exp_sums_ptr[idx + j] * exp_sub_max;
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
share_max.data.elt[j] = exp_sub_max;
|
||||
}
|
||||
}
|
||||
share_max.store_to(share_sum_scale + idx);
|
||||
}
|
||||
|
||||
idx = packed_data_num + tidx;
|
||||
#pragma unroll
|
||||
for (; idx < partition_num; idx += kNReduceThreads) {
|
||||
if (seq_len >= params.use_moba_seq_limit) {
|
||||
if (qk_gate_topk_idx_ptr[idx] != 0) {
|
||||
float share_max = shared_max_logits[idx];
|
||||
float exp_sub_max = expf(share_max - global_max_logit);
|
||||
float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max;
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
share_sum_scale[idx] = exp_sub_max;
|
||||
}
|
||||
} else {
|
||||
float share_max = shared_max_logits[idx];
|
||||
float exp_sub_max = expf(share_max - global_max_logit);
|
||||
float rescaled_exp_sum = exp_sums_ptr[idx] * exp_sub_max;
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
share_sum_scale[idx] = exp_sub_max;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
global_exp_sum = BlockAllReduce<float, SumOp<float>, kNReduceThreads>(global_exp_sum);
|
||||
|
||||
const float inv_global_exp_sum = fdividef(1.0f, global_exp_sum + 1e-6f);
|
||||
return inv_global_exp_sum;
|
||||
}
|
||||
|
||||
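// Merge kernel: combines the per-partition partial outputs, weighting each partition by its
// rescaled exp-sum and skipping partitions whose MoBA gate index is 0 once the sequence
// exceeds use_moba_seq_limit.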
template<typename Kernel_traits, typename ParamType>
|
||||
__global__ void __launch_bounds__(Kernel_traits::kNReduceThreads) moba_decoder_attention_merge_kernel(ParamType params) {
|
||||
using cuteType = typename Kernel_traits::cuteType;
|
||||
constexpr int32_t kBlockN = Kernel_traits::kTileN * Kernel_traits::kBlockSize;
|
||||
constexpr int32_t kNReducePacksize = 16 / sizeof(cuteType);
|
||||
constexpr int32_t kNFloatPacksize = 16 / sizeof(float);
|
||||
constexpr int32_t kNReduceWarps = Kernel_traits::kNReduceWarps;
|
||||
constexpr int32_t kHeadDim = Kernel_traits::kHeadDim;
|
||||
const int32_t bi = blockIdx.z;
|
||||
const int32_t headdim_idx = kNReducePacksize * kNReduceWarps * blockIdx.x;
|
||||
const int32_t tidx = threadIdx.x;
|
||||
const int32_t head_idx = blockIdx.y;
|
||||
const int32_t warp_id = tidx / 32;
|
||||
const int32_t lane_id = tidx % 32;
|
||||
const int32_t seq_len = params.seq_lens_decoder[bi] + 1;
|
||||
const int32_t head_num = params.head_num;
|
||||
using pack_half = typename PackedHalf<cuteType>::Type;
|
||||
|
||||
|
||||
if (params.seq_lens_decoder[bi] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
extern __shared__ char shared_mem[];
|
||||
|
||||
const int32_t partition_num = (seq_len + kBlockN - 1) / kBlockN;
|
||||
const int32_t pack_max_partition_num = (params.max_num_partitions + kNFloatPacksize - 1) / kNFloatPacksize * kNFloatPacksize;
|
||||
|
||||
float* share_sum_scale = reinterpret_cast<float*>(shared_mem + sizeof(float) * pack_max_partition_num);
|
||||
|
||||
constexpr int32_t kGqaGroupSize = Kernel_traits::kGqaGroupSize;
|
||||
const int kv_head_idx = head_idx / Kernel_traits::kGqaGroupSize;
|
||||
const int * qk_gate_topk_idx_ptr = params.qk_gate_topk_idx_ptr + (bi * params.kv_head_num + kv_head_idx) * Kernel_traits::kMaxN;
|
||||
|
||||
float inv_global_exp_sum = caluate_logit_scale<Kernel_traits>(partition_num, pack_max_partition_num, params, shared_mem, seq_len, qk_gate_topk_idx_ptr);
|
||||
|
||||
|
||||
using T_vec = Vec<cuteType, kNReducePacksize>;
|
||||
|
||||
cuteType* partition_attn_out = reinterpret_cast<cuteType*>(params.partition_attn_out) + bi * head_num * params.max_num_partitions * kHeadDim + head_idx * kHeadDim + headdim_idx;
|
||||
|
||||
Vec<float, kNReducePacksize> acc;
|
||||
acc.set_zero();
|
||||
#pragma unroll
|
||||
for (int idx = lane_id; idx < partition_num; idx += 32) {
|
||||
if (seq_len >= params.use_moba_seq_limit && qk_gate_topk_idx_ptr[idx] == 0) {
|
||||
continue;
|
||||
}
|
||||
T_vec sub_logits = *reinterpret_cast<T_vec*>(&partition_attn_out[idx * head_num * kHeadDim + warp_id * kNReducePacksize]);
|
||||
float scale = share_sum_scale[idx];
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kNReducePacksize; ++k) {
|
||||
acc.data.elt[k] += static_cast<float>(sub_logits.data.elt[k]) * scale;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
T_vec out;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kNReducePacksize; ++k) {
|
||||
out.data.elt[k] = static_cast<cuteType>(WarpAllReduce<float, SumOp<float>>(acc.data.elt[k]) * inv_global_exp_sum);
|
||||
}
|
||||
|
||||
const int ori_token_idx = params.cu_seq_q[bi];
|
||||
cuteType * attn_out = reinterpret_cast<cuteType *>(params.attn_out) + ori_token_idx * head_num * kHeadDim + head_idx * kHeadDim + headdim_idx + warp_id * kNReducePacksize;
|
||||
|
||||
if (lane_id == 0) {
|
||||
out.store_to(attn_out);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
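// Two-stage launch: the attention kernel runs one block per (partition, batch, kv_head) and
// writes partial results; the merge kernel then reduces over partitions per
// (head-dim slice, head, batch).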
template<typename Kernel_traits, typename ParamType>
|
||||
void run_moba_decoder_attn(ParamType ¶ms, cudaStream_t stream) {
|
||||
dim3 grid;
|
||||
grid.x = params.max_num_partitions;
|
||||
grid.y = params.batch_size;
|
||||
grid.z = params.kv_head_num;
|
||||
constexpr int smem_size = Kernel_traits::kShareMemSize;
|
||||
constexpr auto kernel = &moba_decoder_attention_kernel<Kernel_traits, ParamType>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
|
||||
int32_t reduce_shared_mem_size = 2 * (params.max_num_partitions + 4) * sizeof(float);
|
||||
constexpr int32_t pack_size = 16 / sizeof(typename Kernel_traits::cuteType);
|
||||
static_assert(Kernel_traits::kHeadDim % pack_size == 0);
|
||||
static_assert((Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps) % pack_size == 0);
|
||||
grid.x = Kernel_traits::kHeadDim / Kernel_traits::kNReduceWarps / pack_size;
|
||||
grid.y = params.head_num;
|
||||
grid.z = params.batch_size;
|
||||
auto reduce_kernel = &moba_decoder_attention_merge_kernel<Kernel_traits, ParamType>;
|
||||
|
||||
if (reduce_shared_mem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
reduce_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, reduce_shared_mem_size);
|
||||
}
|
||||
reduce_kernel<<<grid, Kernel_traits::kNReduceThreads, reduce_shared_mem_size, stream>>>(params);
|
||||
}
|
||||
|
||||
|
||||
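// Head dim is fixed at 128; dispatch on the GQA group size (head_num / kv_head_num) so the
// kernel traits can be instantiated at compile time.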
template<typename cute_type, int kCacheBits, int kBlockN, int kMaxN, typename ParamType>
|
||||
void run_moba_decoder_attn_hdim128(ParamType ¶ms, cudaStream_t stream) {
|
||||
const int gqaGroupSize = params.head_num / params.kv_head_num;
|
||||
using CacheKVTraits = CacheKV_quant_traits<cute_type, kCacheBits>;
|
||||
constexpr int kTileN = kBlockN / CacheKVTraits::kBlockSize;
|
||||
switch (gqaGroupSize) {
|
||||
case 12: {
|
||||
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<12, kTileN, kMaxN,CacheKVTraits>>(params, stream);
|
||||
break;
|
||||
}
|
||||
case 8: {
|
||||
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<8, kTileN, kMaxN,CacheKVTraits>>(params, stream);
|
||||
break;
|
||||
}
|
||||
case 7: {
|
||||
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<7, kTileN, kMaxN,CacheKVTraits>>(params, stream);
|
||||
break;
|
||||
}
|
||||
case 6: {
|
||||
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<6, kTileN, kMaxN,CacheKVTraits>>(params, stream);
|
||||
break;
|
||||
}
|
||||
case 5: {
|
||||
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<5, kTileN, kMaxN,CacheKVTraits>>(params, stream);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
run_moba_decoder_attn<moba_decoder_attn_kernel_traits<4, kTileN, kMaxN,CacheKVTraits>>(params, stream);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"DecoderBlockAttention not implemented for gqaGroupSize = %d", gqaGroupSize));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
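// Allocates the per-partition scratch tensors (maxs, sums, partition_attn_out), fills the
// parameter struct, and dispatches on the cache quantization type.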
template <typename T>
|
||||
void DispatchMobaDecoderAttn(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& qk_gate_topk_idx,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int batch_size,
|
||||
const int max_input_length,
|
||||
const int use_moba_seq_limit,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
const int kMobaBlockSize = 128;
|
||||
const int kMaxN = 1024;
|
||||
|
||||
constexpr int max_seq_per_block = kMobaBlockSize;
|
||||
moba_decoder_attn_params<cute_type> params;
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
const uint32_t max_num_partitions = (max_seq_k + max_seq_per_block) / max_seq_per_block;
|
||||
assert(head_dim == 128);
|
||||
|
||||
paddle::Tensor maxs = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place());
|
||||
paddle::Tensor sums = paddle::empty({batch_size, head_num, (max_num_partitions + 3) / 4 * 4}, paddle::DataType::FLOAT32, q_input.place());
|
||||
paddle::Tensor partition_attn_out = paddle::empty({batch_size, max_num_partitions, head_num, head_dim}, q_input.dtype(), q_input.place());
|
||||
|
||||
params.q_input = reinterpret_cast<cute_type *>(const_cast<T*>(q_input.data<T>()));
|
||||
params.attn_out = reinterpret_cast<cute_type *>(const_cast<T*>(out.data<T>()));
|
||||
params.seq_lens_encoder = const_cast<int*>(seq_len_encoder.data<int>());
|
||||
params.seq_lens_decoder = const_cast<int*>(seq_len_decoder.data<int>());
|
||||
params.block_table = const_cast<int*>(block_tables.data<int>());
|
||||
params.max_input_length = max_input_length;
|
||||
params.head_num = head_num;
|
||||
params.kv_head_num = kv_head_num;
|
||||
params.max_num_blocks_per_seq = block_tables.dims()[1];
|
||||
params.batch_size = batch_size;
|
||||
params.inv_sqrt_dh = 1.0f / std::sqrt(head_dim);
|
||||
params.max_num_partitions = max_num_partitions;
|
||||
params.maxs = reinterpret_cast<float*>(maxs.data<float>());
|
||||
params.sums = reinterpret_cast<float*>(sums.data<float>());
|
||||
params.partition_attn_out = reinterpret_cast<cute_type *>(partition_attn_out.data<T>());
|
||||
params.qk_gate_topk_idx_ptr = const_cast<int*>(qk_gate_topk_idx.data<int>());
|
||||
params.use_moba_seq_limit = use_moba_seq_limit;
|
||||
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
|
||||
|
||||
|
||||
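// Select the cache path: 16-bit (no quantization), int8 with zero points, or int4 with zero points.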
if (cache_quant_type_str == "none") {
|
||||
params.cache_k = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k.data<T>()));
|
||||
params.cache_v = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v.data<T>()));
|
||||
run_moba_decoder_attn_hdim128<cute_type, 16, max_seq_per_block, kMaxN>(params, q_input.stream());
|
||||
} else {
|
||||
params.cache_k = const_cast<uint8_t*>(cache_k.data<uint8_t>());
|
||||
params.cache_v = const_cast<uint8_t*>(cache_v.data<uint8_t>());
|
||||
params.cache_k_quant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_quant_scale.get().data<T>()));
|
||||
params.cache_v_quant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_quant_scale.get().data<T>()));
|
||||
params.cache_k_dequant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_dequant_scale.get().data<T>()));
|
||||
params.cache_v_dequant_scale = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_dequant_scale.get().data<T>()));
|
||||
params.cache_k_zp = reinterpret_cast<cute_type *>(const_cast<T*>(cache_k_zero_points.get().data<T>()));
|
||||
params.cache_v_zp = reinterpret_cast<cute_type *>(const_cast<T*>(cache_v_zero_points.get().data<T>()));
|
||||
if (cache_quant_type_str == "cache_int8_zp") {
|
||||
run_moba_decoder_attn_hdim128<cute_type, 8, max_seq_per_block, kMaxN>(params, q_input.stream());
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
run_moba_decoder_attn_hdim128<cute_type, 4, max_seq_per_block, kMaxN>(params, q_input.stream());
|
||||
} else {
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"GQA Attention not implemented for cache_quant_type_str = %s", cache_quant_type_str.c_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MobaDecoderAttn(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& qk_gate_topk_idx,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
const int batch_size = block_tables.dims()[0];
|
||||
if (q_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
return DispatchMobaDecoderAttn<phi::dtype::float16>(
|
||||
q_input,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cache_k,
|
||||
cache_v,
|
||||
block_tables,
|
||||
k_block_means,
|
||||
out,
|
||||
qk_gate_topk_idx,
|
||||
cache_k_quant_scale,
|
||||
cache_v_quant_scale,
|
||||
cache_k_dequant_scale,
|
||||
cache_v_dequant_scale,
|
||||
cache_k_zero_points,
|
||||
cache_v_zero_points,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
batch_size,
|
||||
max_input_length,
|
||||
use_moba_seq_limit,
|
||||
cache_quant_type_str);
|
||||
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return DispatchMobaDecoderAttn<phi::dtype::bfloat16>(
|
||||
q_input,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cache_k,
|
||||
cache_v,
|
||||
block_tables,
|
||||
k_block_means,
|
||||
out,
|
||||
qk_gate_topk_idx,
|
||||
cache_k_quant_scale,
|
||||
cache_v_quant_scale,
|
||||
cache_k_dequant_scale,
|
||||
cache_v_dequant_scale,
|
||||
cache_k_zero_points,
|
||||
cache_v_zero_points,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
batch_size,
|
||||
max_input_length,
|
||||
use_moba_seq_limit,
|
||||
cache_quant_type_str);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,225 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/extension.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "../moba_attn_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
template <typename T>
|
||||
struct moba_decoder_attn_params {
|
||||
T *__restrict__ q_input;
|
||||
void *__restrict__ cache_k;
|
||||
void *__restrict__ cache_v;
|
||||
|
||||
T *__restrict__ attn_out;
|
||||
T *__restrict__ partition_attn_out;
|
||||
T *__restrict__ cache_k_dequant_scale;
|
||||
T *__restrict__ cache_v_dequant_scale;
|
||||
T *__restrict__ cache_k_quant_scale;
|
||||
T *__restrict__ cache_v_quant_scale;
|
||||
T *__restrict__ cache_k_zp;
|
||||
T *__restrict__ cache_v_zp;
|
||||
int * __restrict__ cu_seq_q;
|
||||
float * sums;
|
||||
float * maxs;
|
||||
int * seq_lens_encoder;
|
||||
int * seq_lens_decoder;
|
||||
int * block_table;
|
||||
int max_input_length;
|
||||
int max_seq_len;
|
||||
int head_num;
|
||||
int kv_head_num;
|
||||
int max_num_blocks_per_seq;
|
||||
float scale_softmax;
|
||||
int batch_size;
|
||||
int max_num_partitions;
|
||||
float inv_sqrt_dh;
|
||||
int *qk_gate_topk_idx_ptr;
|
||||
int use_moba_seq_limit;
|
||||
};
|
||||
|
||||
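// Compile-time traits for the (possibly quantized) KV cache: swizzled shared-memory layouts,
// gmem/smem copy atoms, and the SM80 tiled MMA used by the decoder kernel.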
template <typename cute_type_, int DataBits_>
|
||||
struct CacheKV_quant_traits {
|
||||
using cuteType = cute_type_;
|
||||
static constexpr int kDataBits = DataBits_;
|
||||
static constexpr int kBlockSize = 64;
|
||||
static constexpr int kHeadDim = 128;
|
||||
static constexpr int kBlockKSmem = 64;
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<
|
||||
Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockSize>, Int<kHeadDim>>{}));
|
||||
|
||||
static constexpr int kNWarps = 4;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
|
||||
static constexpr int kThreadPerValue = 16 / sizeof(cuteType);
|
||||
static constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue;
|
||||
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape <Int<kNThreads / kThreadsPerRow>, Int<kThreadsPerRow>>,
|
||||
Stride<Int<kThreadsPerRow>, _1>>;
|
||||
|
||||
using GmemTiledCopyQ = decltype(
|
||||
make_tiled_copy(Copy_Atom<
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cuteType>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<cuteType, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
|
||||
using ValLayoutMNK = Layout<Shape<_1,_4,_1>>;
|
||||
|
||||
using PermutationMNK = Tile<_16, Int<16 * kNWarps>, _16>;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
MMA_Atom_Arch,
|
||||
ValLayoutMNK,
|
||||
PermutationMNK>;
|
||||
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, cuteType>;
|
||||
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockSize>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockSize>>{}));
|
||||
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
|
||||
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, cuteType>;
|
||||
|
||||
static constexpr int kShareMemSize = size(SmemLayoutKV{}) * 2 * sizeof(cuteType);
|
||||
};
|
||||
|
||||
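// Kernel traits derived from the cache traits: kBlockM pads the GQA group size to the 16-row
// MMA tile, and kHeadDimKV shrinks for 8-/4-bit caches because several quantized values pack
// into one 16-bit element.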
template <int kGqaGroupSize_, int kTileN_, int kMaxN_, typename CacheKV_traits_>
|
||||
struct moba_decoder_attn_kernel_traits {
|
||||
using ElementAccum = float;
|
||||
using CacheKV_traits = CacheKV_traits_;
|
||||
using cuteType = typename CacheKV_traits::cuteType;
|
||||
static constexpr int kDataBits = CacheKV_traits::kDataBits;
|
||||
static constexpr int kTileN = kTileN_;
|
||||
static constexpr int kMaxN = kMaxN_;
|
||||
static constexpr int kGqaGroupSize = kGqaGroupSize_;
|
||||
static constexpr int kHeadDim = CacheKV_traits::kHeadDim;
|
||||
static constexpr int kHeadDimKV = kHeadDim / (16 / kDataBits);
|
||||
static constexpr int kMinGemmM = 16;
|
||||
static constexpr int kBlockM = (kGqaGroupSize + kMinGemmM - 1) / kMinGemmM * kMinGemmM;
|
||||
static constexpr int kBlockSize = CacheKV_traits::kBlockSize;
|
||||
static_assert(kGqaGroupSize <= 16);
|
||||
static constexpr int32_t kNWarps = CacheKV_traits::kNWarps;
|
||||
|
||||
static constexpr int kBlockKSmem = CacheKV_traits::kBlockKSmem;
|
||||
static constexpr int kBlockKVSmem = kHeadDimKV <= 64 ? kHeadDimKV : 64;
|
||||
static_assert(kHeadDim % kBlockKSmem == 0);
|
||||
static constexpr int kNReduceWarps = 4;
|
||||
static constexpr int kNReduceThreads = kNReduceWarps * 32;
|
||||
|
||||
|
||||
using SmemLayoutAtomQ = typename CacheKV_traits::SmemLayoutAtomQ;
|
||||
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutQK = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kBlockSize>>{}));
|
||||
|
||||
using SmemLayoutAtomKV = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<
|
||||
Shape<Int<8>, Int<kBlockKVSmem>>,
|
||||
Stride<Int<kBlockKVSmem>, _1>>{}));
|
||||
|
||||
using SmemLayoutKV_ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKV{},
|
||||
Shape<Int<kBlockSize>, Int<kHeadDimKV>>{}));
|
||||
|
||||
using SmemLayoutKV = std::conditional_t<
|
||||
kDataBits == 16,
|
||||
SmemLayoutKV_,
|
||||
decltype(get_nonswizzle_portion(SmemLayoutKV_{}))
|
||||
>;
|
||||
|
||||
constexpr static int kBlockKVSize = kDataBits == 4 ? 32 : kBlockSize;
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockKVSize>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockKVSize>>{}));
|
||||
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
|
||||
|
||||
static constexpr int kThreadsPerRow = CacheKV_traits::kThreadsPerRow;
|
||||
static constexpr int kThreadsKVPerRow = kThreadsPerRow / (16 / kDataBits);
|
||||
static constexpr int kNThreads = CacheKV_traits::kNThreads;
|
||||
|
||||
using GmemKVLayoutAtom = Layout<
|
||||
Shape<Int<kNThreads / kThreadsKVPerRow>, Int<kThreadsKVPerRow>>,
|
||||
Stride<Int<kThreadsKVPerRow>, _1>>;
|
||||
|
||||
using SmemCopyAtom = typename CacheKV_traits::SmemCopyAtom;
|
||||
using TiledMma = typename CacheKV_traits::TiledMma;
|
||||
|
||||
static constexpr int kThreadPerValue = CacheKV_traits::kThreadPerValue;
|
||||
|
||||
using GmemTiledCopyQ = typename CacheKV_traits::GmemTiledCopyQ;
|
||||
using GmemLayoutAtom = typename CacheKV_traits::GmemLayoutAtom;
|
||||
using GmemTiledCopyKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cuteType>{},
|
||||
GmemKVLayoutAtom{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
|
||||
|
||||
using SmemCopyAtomTransposed = typename CacheKV_traits::SmemCopyAtomTransposed;
|
||||
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, cuteType>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, cuteType>;
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<3, 3, 3>{},
|
||||
Layout<
|
||||
Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
static constexpr int kShareMemSize = (size(SmemLayoutQ{}) + size(SmemLayoutQK{}) + size(SmemLayoutKV{}) * 2) * sizeof(cuteType);
|
||||
};
|
||||
@@ -0,0 +1,189 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "../moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
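// Per-token decoder write path: adds the QKV bias, applies rotary embedding to Q and K,
// appends the new K/V entry to the paged cache, and maintains the running mean of K for each
// MoBA block used by the gating step.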
template <typename T, int kBlockSize, int kHeadDim, int moba_block_size, int kMaxN>
|
||||
__global__ void moba_decoder_attn_write_c16(
|
||||
const T * qkv_out,
|
||||
const T * qkv_bias,
|
||||
T * q_input,
|
||||
const int * cu_seq_q,
|
||||
const int * cu_seq_k,
|
||||
const int * seq_len_encoder,
|
||||
const int * seq_len_decoder,
|
||||
T * cache_k,
|
||||
T * cache_v,
|
||||
const int * block_tables,
|
||||
const float * rope_sin_cos,
|
||||
T *k_block_means,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int max_blocks_per_seq,
|
||||
const int max_input_length) {
|
||||
|
||||
int bidh = blockIdx.x;
|
||||
const int bidb = blockIdx.y;
|
||||
const int tidx = threadIdx.x;
|
||||
const int seq_len = seq_len_decoder[bidb];
|
||||
|
||||
if (seq_len == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int kPackSize = 4;
|
||||
using SrcType = Vec<T, kPackSize>;
|
||||
using rope_type = Vec<float, kPackSize / 2>;
|
||||
SrcType src, bias, k_prev;
|
||||
rope_type sin, cos;
|
||||
const int bias_idx = bidh * kHeadDim + tidx * kPackSize;
|
||||
const int ori_token_idx = cu_seq_q[bidb];
|
||||
src.load_from(qkv_out + ori_token_idx * (head_num + 2 * kv_head_num) * kHeadDim + bias_idx);
|
||||
if (qkv_bias != nullptr) {
|
||||
bias.load_from(qkv_bias + bias_idx);
|
||||
src.add(bias);
|
||||
}
|
||||
|
||||
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
|
||||
const int32_t physical_block_number = block_table_now[seq_len / kBlockSize];
|
||||
|
||||
|
||||
if (bidh < head_num) {
|
||||
const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
apply_rotary_embedding<T, kPackSize>(src, cos, sin);
|
||||
|
||||
src.store_to(q_input + cu_seq_q[bidb] * head_num * kHeadDim + bias_idx);
|
||||
} else if (bidh < head_num + kv_head_num) {
|
||||
bidh -= head_num;
|
||||
const int token_in_blocks = seq_len % kBlockSize;
|
||||
const float * cos_rope = rope_sin_cos + seq_len * (kHeadDim / 2) + tidx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
apply_rotary_embedding<T, kPackSize>(src, cos, sin);
|
||||
|
||||
T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim;
|
||||
src.store_to(cache);
|
||||
|
||||
const int seq_len_block = seq_len / moba_block_size;
|
||||
|
||||
const int store_mean_idx = (bidb * kMaxN + seq_len_block) * kv_head_num * kHeadDim + bidh * kHeadDim + tidx * kPackSize;
|
||||
|
||||
if (seq_len % moba_block_size != 0) {
|
||||
const int token_num_prev = seq_len % moba_block_size;
|
||||
const float inv_tokens_sum = fdividef(1.0f, token_num_prev + 1);
|
||||
k_prev.load_from(k_block_means + store_mean_idx);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
src.data.elt[i] = T(inv_tokens_sum * (float(src.data.elt[i]) + float(k_prev.data.elt[i]) * token_num_prev));
|
||||
}
|
||||
}
|
||||
|
||||
src.store_to(k_block_means + store_mean_idx);
|
||||
|
||||
} else {
|
||||
bidh -= head_num + kv_head_num;
|
||||
const int token_in_blocks = seq_len % kBlockSize;
|
||||
T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + tidx * kPackSize + token_in_blocks * kHeadDim;
|
||||
src.store_to(cache);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void MobaDecoderAttnWriteCacheKv(
|
||||
const paddle::Tensor& qkv_out,
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& rope_sin_cos,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
constexpr int kThreads = 32;
|
||||
constexpr int kHeadDim = 128;
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
assert(kHeadDim == head_dim);
|
||||
constexpr int kBlockSize = 64;
|
||||
const int max_blocks_per_seq = block_tables.dims()[1];
|
||||
const int batch_size = block_tables.dims()[0];
|
||||
if (cache_quant_type_str == "none") {
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = head_num + kv_head_num * 2;
|
||||
grid_dims.y = batch_size;
|
||||
if (qkv_out.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
moba_decoder_attn_write_c16<T, kBlockSize, kHeadDim, kMobaBlockSize, kMaxN><<<grid_dims, kThreads, 0, qkv_out.stream()>>>(
|
||||
qkv_out.data<T>(),
|
||||
qkv_bias ? qkv_bias.get().data<T>() : nullptr,
|
||||
const_cast<T*>(q_input.data<T>()),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
const_cast<T *>(cache_k.data<T>()),
|
||||
const_cast<T *>(cache_v.data<T>()),
|
||||
block_tables.data<int>(),
|
||||
rope_sin_cos.data<float>(),
|
||||
const_cast<T*>(k_block_means.data<T>()),
|
||||
head_num,
|
||||
kv_head_num,
|
||||
max_blocks_per_seq,
|
||||
max_input_length);
|
||||
} else if (qkv_out.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
moba_decoder_attn_write_c16<T, kBlockSize, kHeadDim, kMobaBlockSize, kMaxN><<<grid_dims, kThreads, 0, qkv_out.stream()>>>(
|
||||
qkv_out.data<T>(),
|
||||
qkv_bias ? qkv_bias.get().data<T>() : nullptr,
|
||||
const_cast<T*>(q_input.data<T>()),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
const_cast<T *>(cache_k.data<T>()),
|
||||
const_cast<T *>(cache_v.data<T>()),
|
||||
block_tables.data<int>(),
|
||||
rope_sin_cos.data<float>(),
|
||||
const_cast<T*>(k_block_means.data<T>()),
|
||||
head_num,
|
||||
kv_head_num,
|
||||
max_blocks_per_seq,
|
||||
max_input_length);
|
||||
}
|
||||
} else {
|
||||
PD_THROW("Only supported cache_quant_type_str in ['none'].");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
|
||||
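// For each (batch, head), binary-search a score threshold so that the number of selected MoBA
// blocks falls within [top_k_left, top_k_right], then mark the selected blocks; the first block
// and the last two blocks are always kept.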
template <typename T, int knthreads, int moba_block_size, int kBlockMaxN, int searchtimes>
|
||||
__global__ void qk_gate_sort_decoder_kernel(
|
||||
const T* qk_gate_weight,
|
||||
int * qk_gate_topk_idx,
|
||||
const int *decoder_seq_lens,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int kGqaGroupSize,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
const int bidb = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int tidx = threadIdx.x;
|
||||
const int bidh_kv = bidh / kGqaGroupSize;
|
||||
|
||||
if (decoder_seq_lens[bidb] == 0 || decoder_seq_lens[bidb] < use_moba_seq_limit) {
|
||||
return;
|
||||
}
|
||||
const int seq_len = (decoder_seq_lens[bidb] + moba_block_size - 1) / moba_block_size;
|
||||
|
||||
constexpr int kPackSize = kBlockMaxN / knthreads;
|
||||
|
||||
static_assert(kBlockMaxN % knthreads == 0);
|
||||
|
||||
T token_mean[kPackSize];
|
||||
|
||||
using SrcType = Vec<T, kPackSize>;
|
||||
using SrcType_f = Vec<float, kPackSize>;
|
||||
using SrcType_i = Vec<int, kPackSize>;
|
||||
|
||||
SrcType src;
|
||||
SrcType_f src_f;
|
||||
SrcType_i select_idx;
|
||||
|
||||
select_idx.set_zero();
|
||||
|
||||
const int load_offset = bidb * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize;
|
||||
|
||||
src.load_from(qk_gate_weight + load_offset);
|
||||
|
||||
float max_global = -FLT_MAX;
|
||||
float min_global = FLT_MAX;
|
||||
|
||||
const int data_len = seq_len - tidx * kPackSize;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
if (i < data_len) {
|
||||
src_f.data.elt[i] = float(src.data.elt[i]);
|
||||
min_global = min(min_global, src_f.data.elt[i]);
|
||||
} else {
|
||||
src_f.data.elt[i] = -FLT_MAX;
|
||||
}
|
||||
max_global = max(max_global, src_f.data.elt[i]);
|
||||
}
|
||||
|
||||
|
||||
max_global = BlockAllReduce<float, MaxOp<float>, knthreads>(max_global);
|
||||
min_global = BlockAllReduce<float, MinOp<float>, knthreads>(min_global);
|
||||
|
||||
|
||||
float right_limit = max_global;
|
||||
float left_limit = min_global;
|
||||
|
||||
float mid_limit;
|
||||
int count;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < searchtimes; i++) {
|
||||
mid_limit = (left_limit + right_limit) * 0.5f;
|
||||
count = get_data_count<kPackSize, knthreads>(src_f.data.elt, mid_limit);
|
||||
if (count < top_k_left) {
|
||||
right_limit = mid_limit;
|
||||
} else if (count > top_k_right) {
|
||||
left_limit = mid_limit;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const int store_idx = bidb * kv_head_num * kBlockMaxN + bidh_kv * kBlockMaxN + tidx * kPackSize;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
if (src_f.data.elt[i] >= mid_limit) {
|
||||
qk_gate_topk_idx[store_idx + i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (tidx == 0) {
|
||||
qk_gate_topk_idx[store_idx] = 1;
|
||||
qk_gate_topk_idx[store_idx + seq_len - 1] = 1;
|
||||
qk_gate_topk_idx[store_idx + seq_len - 2] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <int kBlockMaxN, int moba_block_size, typename T>
|
||||
void qk_gate_sort_decoder(
|
||||
const T* qk_gate_weight,
|
||||
int * qk_gate_topk_idx,
|
||||
const int *decoder_seq_lens,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int batch_size,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int gqa_group_size = head_num / kv_head_num;
|
||||
constexpr int kPackSize = 16 / sizeof(T);
|
||||
const int knthreads = kBlockMaxN / kPackSize;
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = batch_size;
|
||||
grid_dims.y = head_num;
|
||||
const int searchtimes = 6;
|
||||
|
||||
constexpr auto kernel = qk_gate_sort_decoder_kernel<T, knthreads, moba_block_size, kBlockMaxN, searchtimes>;
|
||||
|
||||
kernel<<<grid_dims, knthreads, 0, 0>>>(
|
||||
qk_gate_weight,
|
||||
qk_gate_topk_idx,
|
||||
decoder_seq_lens,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
gqa_group_size,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::vector<paddle::Tensor> DispatchQkSortDecoder(
|
||||
const paddle::Tensor& qk_gate_weight,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
|
||||
const int batch_size = seq_len_decoder.dims()[0];
|
||||
paddle::Tensor qk_gate_topk_idx = paddle::empty({batch_size, kv_head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place());
|
||||
|
||||
qk_gate_sort_decoder<kMaxN, kMobaBlockSize, T>(
|
||||
qk_gate_weight.data<T>(),
|
||||
qk_gate_topk_idx.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
head_num,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit,
|
||||
qk_gate_weight.stream()
|
||||
);
|
||||
|
||||
return {qk_gate_topk_idx};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> QkSortDecoder(
|
||||
const paddle::Tensor& qk_gate_weight,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) {
|
||||
return std::move(
|
||||
DispatchQkSortDecoder<phi::dtype::float16>(
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit)
|
||||
);
|
||||
} else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return std::move(
|
||||
DispatchQkSortDecoder<phi::dtype::bfloat16>(
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit)
|
||||
);
|
||||
} else {
    PD_THROW("moba_qk_sort_decoder only supports float16 and bfloat16 for qk_gate_weight.");
}
}
|
||||
|
||||
PD_BUILD_OP(moba_qk_sort_decoder)
|
||||
.Inputs({
|
||||
"qk_gate_weight",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder"})
|
||||
.Attrs({
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"top_k_left: int",
|
||||
"top_k_right: int",
|
||||
"use_moba_seq_limit: int"})
|
||||
.Outputs({"qk_gate_topk_idx"})
|
||||
.SetKernelFn(PD_KERNEL(QkSortDecoder));
|
||||
custom_ops/gpu_ops/moba_attn/moba_encoder_attn/kernel_traits.h (new file, 143 lines)
@@ -0,0 +1,143 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct moba_encoder_attn_params {
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
void * __restrict__ o_ptr;
|
||||
int * __restrict__ cu_seq_q;
|
||||
int * __restrict__ cu_seq_k;
|
||||
int * __restrict__ qk_gate_topk_idx;
|
||||
int * __restrict__ seq_len_encoder;
|
||||
int * __restrict__ cu_seq_q_pack;
|
||||
int head_num;
|
||||
int kv_head_num;
|
||||
int max_seq_q;
|
||||
int max_seq_k;
|
||||
int batch_size;
|
||||
int gqa_group_size;
|
||||
float scale_softmax_log2;
|
||||
int use_moba_seq_limit;
|
||||
};
|
||||
|
||||
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
|
||||
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
|
||||
struct SharedStorageQKVO {
|
||||
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
|
||||
union {
|
||||
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
|
||||
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
struct {
|
||||
cutlass::arch::ClusterTransactionBarrier barrier_Q;
|
||||
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
|
||||
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
|
||||
};
|
||||
};
|
||||
|
||||
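// Encoder (prefill) kernel traits for SM90: WGMMA tile shape (kBlockM, kBlockN, kHeadDim),
// multi-stage shared-memory layouts for K/V, and the TMA pipeline types used by the
// warp-specialized mainloop.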
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, int kMaxN_, bool UseMoba_, typename elem_type=cutlass::half_t>
|
||||
struct moba_encoder_attn_kernel_traits {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using index_t = int32_t;
|
||||
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
|
||||
|
||||
static constexpr int UseMoba = UseMoba_;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static constexpr int kMaxN = kMaxN_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
|
||||
using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
|
||||
static constexpr int kStages = kStages_;
|
||||
|
||||
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
|
||||
using TiledMma0 = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
|
||||
AtomLayoutMNK{}));
|
||||
using TiledMma1 = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutK =
|
||||
decltype(tile_to_shape(SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
|
||||
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutV =
|
||||
decltype(tile_to_shape(SmemLayoutAtomV{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
|
||||
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
|
||||
|
||||
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
|
||||
|
||||
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
|
||||
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
|
||||
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
|
||||
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
|
||||
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
|
||||
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
|
||||
using TiledCopyOThrLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
|
||||
LayoutRight{}));
|
||||
using TiledCopyOValLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(_1{}, Int<kNumVecElem>{}),
|
||||
LayoutRight{}));
|
||||
using GmemTiledCopyO = decltype(make_tiled_copy(
|
||||
TiledCopyOAtom{},
|
||||
TiledCopyOThrLayout{}, // Thr layout
|
||||
TiledCopyOValLayout{} // Val layout
|
||||
));
|
||||
|
||||
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
|
||||
using PipelineState = typename cutlass::PipelineState<kStages>;
|
||||
};
|
||||
custom_ops/gpu_ops/moba_attn/moba_encoder_attn/mainloop_attn.hpp (new file, 473 lines)
@@ -0,0 +1,473 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
enum class AttnNamedBarriers {
|
||||
QueryEmpty = 0,
|
||||
ValueEmpty = 1,
|
||||
TileCountSmemEmpty = 2,
|
||||
TileCountSmemFull = 3,
|
||||
WarpSchedulerWG1 = 4,
|
||||
WarpSchedulerWG2 = 5,
|
||||
WarpSchedulerWG3 = 6,
|
||||
};
|
||||
|
||||
|
||||
|
||||
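// Hopper mainloop for the encoder: Q is loaded once via TMA, K/V are streamed through a
// multi-stage TMA pipeline, and SmemLayoutVt provides the transposed view of V for the
// second GEMM.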
template <typename Ktraits>
|
||||
struct CollectiveMainloopAttn {
|
||||
|
||||
using Element = typename Ktraits::Element;
|
||||
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
||||
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
||||
|
||||
static constexpr int kStages = Ktraits::kStages;
|
||||
static constexpr int kHeadDim = Ktraits::kHeadDim;
|
||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||
|
||||
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
|
||||
using StrideT = cute::Shape<int32_t, _1, int32_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
|
||||
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
|
||||
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
|
||||
using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;
|
||||
|
||||
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutK =
|
||||
decltype(tile_to_shape(SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
using SmemLayoutV = SmemLayoutK;
|
||||
// Note this is the transpose in terms of the view, not in terms of memory.
|
||||
using SmemLayoutVt =
|
||||
decltype(cute::composition(SmemLayoutV{},
|
||||
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
|
||||
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
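// The composition above only re-indexes the staged V tile as (kHeadDim, kBlockN, kStages)
// so it can be fed directly as the B operand of the P @ V GEMM; no data is moved in shared memory.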
|
||||
using SmemLayoutO = typename Ktraits::SmemLayoutO;
|
||||
using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;
|
||||
|
||||
using TMA_Q = decltype(make_tma_copy(
|
||||
GmemTiledCopyQ{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
||||
repeat_like(StrideT{}, int32_t(0)),
|
||||
StrideT{}
|
||||
),
|
||||
SmemLayoutQ{},
|
||||
select<0, 2>(TileShape_MNK{}),
|
||||
_1{})); // no mcast for Q
|
||||
|
||||
using TMA_KV = decltype(make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<Element const*>(nullptr)),
|
||||
repeat_like(StrideT{}, int32_t(0)),
|
||||
StrideT{}
|
||||
),
|
||||
take<0, 2>(SmemLayoutK{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
|
||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
|
||||
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
|
||||
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
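// Transaction bytes = number of smem elements in the per-stage tile times the element size in bytes.
// K and V share the same per-stage tile shape (SmemLayoutV == SmemLayoutK), so TmaTransactionBytesK
// is reused for both pipelines when they are constructed in the kernel.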
|
||||
|
||||
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
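// The named-barrier hand-off between the MMA warp groups (warp_scheduler_barrier_sync /
// warp_scheduler_barrier_arrive below) is only enabled for head dims up to 128.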
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
Element const* ptr_Q;
|
||||
LayoutT layout_Q;
|
||||
Element const* ptr_K;
|
||||
LayoutT layout_K;
|
||||
Element const* ptr_V;
|
||||
LayoutT layout_V;
|
||||
float const softmax_scale_log2;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
LayoutT layout_Q;
|
||||
LayoutT layout_K;
|
||||
LayoutT layout_V;
|
||||
cutlass::FastDivmod qhead_per_khead_divmod;
|
||||
TMA_Q tma_load_Q;
|
||||
TMA_KV tma_load_K, tma_load_V;
|
||||
float const softmax_scale_log2;
|
||||
};
|
||||
|
||||
|
||||
static Params
|
||||
to_underlying_arguments(Arguments const& args) {
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
|
||||
TMA_Q tma_load_Q = make_tma_copy(
|
||||
GmemTiledCopyQ{},
|
||||
mQ,
|
||||
SmemLayoutQ{},
|
||||
select<0, 2>(TileShape_MNK{}),
|
||||
_1{}); // no mcast for Q
|
||||
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
|
||||
TMA_KV tma_load_K = make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
mK,
|
||||
SmemLayoutK{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
|
||||
TMA_KV tma_load_V = make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
mV,
|
||||
SmemLayoutV{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {args.layout_Q, args.layout_K, args.layout_V,
|
||||
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
|
||||
tma_load_Q, tma_load_K, tma_load_V,
|
||||
args.softmax_scale_log2};
|
||||
}
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_local_tile_tensor(
|
||||
const MTensor &m_tensor,
|
||||
const Shape &tile_shape,
|
||||
const int *cu_seq_len,
|
||||
const int bidh,
|
||||
const int bidb,
|
||||
const int actual_seq_len) const {
|
||||
auto g_offset = local_tile(
|
||||
m_tensor(_, _, bidh),
|
||||
cute::make_shape(1, get<1>(tile_shape)),
|
||||
make_coord(cu_seq_len[bidb], _0{}));
|
||||
auto g_sequence = make_tensor(
|
||||
g_offset.data(),
|
||||
make_layout(
|
||||
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
|
||||
g_offset.stride()
|
||||
));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
|
||||
template <bool UseMoba, typename SharedStorage>
|
||||
CUTLASS_DEVICE void
|
||||
load(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_k,
|
||||
MainloopPipeline pipeline_v,
|
||||
PipelineState& smem_pipe_write_k,
|
||||
PipelineState& smem_pipe_write_v,
|
||||
SharedStorage &shared_storage,
|
||||
const int *qk_gate_topk_idx,
|
||||
const int n_block_max,
|
||||
const int m_block,
|
||||
const int bidh,
|
||||
const int bidb,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int seq_len_q,
|
||||
const int seq_len_k) {
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
|
||||
|
||||
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
|
||||
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
|
||||
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
|
||||
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
|
||||
|
||||
Tensor gQ = get_local_tile_tensor(
|
||||
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
|
||||
Tensor gK = get_local_tile_tensor(
|
||||
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
|
||||
Tensor gV = get_local_tile_tensor(
|
||||
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
|
||||
|
||||
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
|
||||
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
|
||||
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
|
||||
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK));
|
||||
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV));
|
||||
|
||||
uint16_t mcast_mask_kv = 0;
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
|
||||
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
|
||||
}
|
||||
|
||||
|
||||
if (lane_predicate) {
|
||||
pipeline_k.producer_acquire(smem_pipe_write_k);
|
||||
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
|
||||
++smem_pipe_write_k;
|
||||
}
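// Main producer loop: with UseMoba, qk_gate_topk_idx holds, for each selected KV block, the
// backward distance to the previously selected block, so the next K tile is fetched at
// n_block - pre_idx while V for the current n_block is loaded before the jump.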
|
||||
|
||||
if (lane_predicate) {
|
||||
int idx = 0;
|
||||
#pragma unroll 2
|
||||
for (; n_block > 0; ) {
|
||||
pipeline_k.producer_acquire(smem_pipe_write_k);
|
||||
int pre_idx = 1;
|
||||
if constexpr (UseMoba) {
|
||||
pre_idx = qk_gate_topk_idx[idx];
|
||||
}
|
||||
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block - pre_idx), tKsK(_, smem_pipe_write_k.index()));
|
||||
|
||||
++smem_pipe_write_k;
|
||||
pipeline_v.producer_acquire(smem_pipe_write_v);
|
||||
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
|
||||
++smem_pipe_write_v;
|
||||
n_block -= pre_idx;
|
||||
idx += 1;
|
||||
}
|
||||
}
|
||||
if (lane_predicate) {
|
||||
pipeline_v.producer_acquire(smem_pipe_write_v);
|
||||
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
|
||||
++smem_pipe_write_v;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
warp_scheduler_barrier_sync() {
|
||||
if constexpr (UseSchedulerBarrier) {
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
mma_init() {
|
||||
if constexpr (!UseSchedulerBarrier) { return; }
|
||||
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
|
||||
if (cutlass::canonical_warp_group_idx() > 1) {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
|
||||
}
|
||||
if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
|
||||
if (cutlass::canonical_warp_group_idx() > 2) {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
warp_scheduler_barrier_arrive() {
|
||||
if constexpr (!UseSchedulerBarrier) { return; }
|
||||
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
|
||||
if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
|
||||
} else {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <bool UseMoba, typename SharedStorage, typename FrgTensorO, typename Softmax>
|
||||
CUTLASS_DEVICE void
|
||||
mma(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_k,
|
||||
MainloopPipeline pipeline_v,
|
||||
PipelineState& smem_pipe_read_k,
|
||||
PipelineState& smem_pipe_read_v,
|
||||
FrgTensorO& tOrO,
|
||||
Softmax& softmax,
|
||||
const int *qk_gate_topk_idx,
|
||||
const int n_block_max,
|
||||
const int thread_idx,
|
||||
const int m_block,
|
||||
const int seq_len_q,
|
||||
const int seq_len_k,
|
||||
SharedStorage& shared_storage) {
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
|
||||
|
||||
typename Ktraits::TiledMma0 tiled_mma0;
|
||||
typename Ktraits::TiledMma1 tiled_mma1;
|
||||
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
|
||||
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMma0.partition_fragment_B(sK);
|
||||
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
|
||||
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }
|
||||
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
warp_scheduler_barrier_sync();
|
||||
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
warp_scheduler_barrier_arrive();
|
||||
warpgroup_wait<0>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k);
|
||||
++smem_pipe_read_k;
|
||||
|
||||
auto col_limit_causal = [&](int row, int n_block) {
|
||||
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
|
||||
};
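// col_limit_causal maps a row of this query tile to the last K column it may attend to, with the
// causal diagonal anchored to the end of the sequence (the seq_len_k - seq_len_q offset); entries
// at or beyond that column are set to -INFINITY below before the online softmax.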
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
|
||||
Tensor tScS = threadMma0.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
if (int(get<1>(tScS(i))) >=
|
||||
std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
|
||||
tSrS(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
|
||||
|
||||
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
|
||||
Tensor scores_scale = make_fragment_like(softmax.row_max);
|
||||
clear(scores_scale);
|
||||
|
||||
int idx = 0;
|
||||
#pragma unroll 2
|
||||
for (; n_block > 0; ) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
warp_scheduler_barrier_sync();
|
||||
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
warp_scheduler_barrier_arrive();
|
||||
warpgroup_wait<1>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k); // release K
|
||||
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
||||
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_v.consumer_release(smem_pipe_read_v); // release V
|
||||
++smem_pipe_read_k;
|
||||
++smem_pipe_read_v;
|
||||
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
|
||||
if constexpr (UseMoba) {
|
||||
n_block -= qk_gate_topk_idx[idx];
|
||||
idx += 1;
|
||||
} else {
|
||||
n_block -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_v.consumer_release(smem_pipe_read_v);
|
||||
++smem_pipe_read_v;
|
||||
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
}
|
||||
|
||||
template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
|
||||
CUTLASS_DEVICE void
|
||||
store(Params const& mainloop_params,
|
||||
FrgTensorO const& tOrO,
|
||||
SharedStorage& shared_storage,
|
||||
TiledMma tiled_mma,
|
||||
int thread_idx,
|
||||
const int o_head_stride,
|
||||
const int real_seq,
|
||||
T * out_ptr) {
|
||||
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOrO_out = convert_type<Element>(tOrO);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(AttnNamedBarriers::ValueEmpty) /*id*/);
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
|
||||
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(o_head_stride, _1{}));
|
||||
|
||||
GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
|
||||
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
|
||||
|
||||
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
|
||||
|
||||
if (real_seq >= kBlockM) {
|
||||
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
|
||||
} else {
|
||||
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
@@ -0,0 +1,384 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include "cutlass/util/print_error.hpp"
|
||||
#include "cutlass/util/GPU_Clock.hpp"
|
||||
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
|
||||
# include "cutlass/util/cublas_wrappers.hpp"
|
||||
#endif
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
#include "kernel_traits.h"
|
||||
#include "mainloop_attn.hpp"
|
||||
#include "softmax.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
|
||||
template <int kHeadDim>
|
||||
auto get_gmem_layout(int token_num, int head_num) {
|
||||
return make_layout(
|
||||
make_shape(token_num, kHeadDim, head_num),
|
||||
make_stride(head_num * kHeadDim, _1{}, kHeadDim));
|
||||
}
|
||||
|
||||
template <typename Ktraits>
|
||||
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
|
||||
moba_encoder_attention_kernel(
|
||||
CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT moba_encoder_attn_params const data_params) {
|
||||
|
||||
using Element = typename Ktraits::Element;
|
||||
using ElementAccum = typename Ktraits::ElementAccum;
|
||||
using SoftType = ElementAccum;
|
||||
using TileShape_MNK = typename Ktraits::TileShape_MNK;
|
||||
using ClusterShape = typename Ktraits::ClusterShape_MNK;
|
||||
|
||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
|
||||
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||
constexpr int kHeadDim = Ktraits::kHeadDim;
|
||||
constexpr int kMaxN = Ktraits::kMaxN;
|
||||
|
||||
using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
|
||||
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||
|
||||
const int m_block = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
|
||||
const int seq_len_q = data_params.seq_len_encoder[bidb];
|
||||
const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];
|
||||
|
||||
|
||||
if (seq_len_q == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
__align__(16) __shared__ int qk_gate_topk_idx[kMaxN];
|
||||
const int *qk_gate_idx_cur_offset = data_params.qk_gate_topk_idx + data_params.cu_seq_q_pack[bidb] / kBlockM * data_params.head_num * kMaxN + (m_block * data_params.head_num + bidh) * kMaxN;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = threadIdx.x; i < kMaxN / 4; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
|
||||
reinterpret_cast<int4*>(qk_gate_topk_idx)[i] = reinterpret_cast<const int4*>(qk_gate_idx_cur_offset)[i];
|
||||
}
|
||||
|
||||
|
||||
const int n_block_max = min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN), cute::ceil_div(seq_len_k, kBlockN));
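// n_block_max is the number of KV tiles this causal query tile can attend to,
// capped by the total number of KV tiles for the sequence.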
|
||||
|
||||
if (m_block * kBlockM >= seq_len_q) {
|
||||
return;
|
||||
}
|
||||
|
||||
int const lane_predicate = cute::elect_one_sync();
|
||||
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
|
||||
if (warp_idx == 0 && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
||||
}
|
||||
|
||||
// Obtain thread index within the warp group
|
||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
pipeline_params.role = warp_group_idx == 0
|
||||
? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = NumMmaThreads;
|
||||
|
||||
if (warp_idx == 0 && lane_predicate) {
|
||||
shared_storage.barrier_Q.init(1);
|
||||
}
|
||||
|
||||
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
|
||||
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
|
||||
if (warp_group_idx == 0) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();
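// The producer warp group releases most of its registers to the consumer warp groups; only its
// first warp (warp_idx_in_warpgroup == 0) issues the TMA loads below.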
|
||||
|
||||
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
|
||||
collective_mainloop.load<Ktraits::UseMoba>(
|
||||
mainloop_params,
|
||||
pipeline_k,
|
||||
pipeline_v,
|
||||
smem_pipe_write_k,
|
||||
smem_pipe_write_v,
|
||||
shared_storage,
|
||||
qk_gate_topk_idx,
|
||||
n_block_max,
|
||||
m_block,
|
||||
bidh,
|
||||
bidb,
|
||||
data_params.cu_seq_q,
|
||||
data_params.cu_seq_k,
|
||||
seq_len_q,
|
||||
seq_len_k);
|
||||
}
|
||||
} else {
|
||||
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
|
||||
typename Ktraits::TiledMma1 tiled_mma1;
|
||||
|
||||
collective_mainloop.mma_init();
|
||||
|
||||
PipelineState smem_pipe_read_k, smem_pipe_read_v;
|
||||
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
|
||||
Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
|
||||
|
||||
collective_mainloop.mma<Ktraits::UseMoba>(
|
||||
mainloop_params,
|
||||
pipeline_k,
|
||||
pipeline_v,
|
||||
smem_pipe_read_k,
|
||||
smem_pipe_read_v,
|
||||
tOrO,
|
||||
softmax,
|
||||
qk_gate_topk_idx,
|
||||
n_block_max,
|
||||
threadIdx.x - NumCopyThreads,
|
||||
m_block,
|
||||
seq_len_q,
|
||||
seq_len_k,
|
||||
shared_storage);
|
||||
|
||||
const int o_head_stride = data_params.head_num * kHeadDim;
|
||||
const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;
|
||||
|
||||
const int real_seq = seq_len_q - m_block * kBlockM;
|
||||
|
||||
collective_mainloop.store<NumMmaThreads>(
|
||||
mainloop_params,
|
||||
tOrO,
|
||||
shared_storage,
|
||||
tiled_mma1,
|
||||
threadIdx.x - NumCopyThreads,
|
||||
o_head_stride,
|
||||
real_seq,
|
||||
reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_moba_decoder_attn(moba_encoder_attn_params ¶ms, cudaStream_t stream) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
|
||||
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
|
||||
|
||||
using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
typename CollectiveMainloop::Params mainloop_params =
|
||||
CollectiveMainloop::to_underlying_arguments({
|
||||
static_cast<Element const*>(params.q_ptr),
|
||||
get_gmem_layout<kHeadDim>(params.max_seq_q * params.batch_size, params.head_num),
|
||||
static_cast<Element const*>(params.k_ptr),
|
||||
get_gmem_layout<kHeadDim>(params.max_seq_k * params.batch_size, params.kv_head_num),
|
||||
static_cast<Element const*>(params.v_ptr),
|
||||
get_gmem_layout<kHeadDim>(params.max_seq_k * params.batch_size, params.kv_head_num),
|
||||
params.scale_softmax_log2
|
||||
});
|
||||
|
||||
int num_blocks_m = cutlass::ceil_div(params.max_seq_q, Kernel_traits::kBlockM);
|
||||
|
||||
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
|
||||
|
||||
void *kernel;
|
||||
kernel = (void *)moba_encoder_attention_kernel<Kernel_traits>;
|
||||
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = num_blocks_m;
|
||||
grid_dims.y = params.head_num;
|
||||
grid_dims.z = params.batch_size;
|
||||
|
||||
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
|
||||
dim3 block_dims(ctaSize);
|
||||
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
|
||||
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
|
||||
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
|
||||
}
|
||||
|
||||
|
||||
template <int kBlockM, int kBlockN, int kMaxN, typename InputType>
|
||||
void run_moba_encoder_attn_hdim128(moba_encoder_attn_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
constexpr static int Headdim = 128;
|
||||
constexpr static int kNWarps = kBlockM / 16 + 4;
|
||||
constexpr static int kStages = 2;
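// With kBlockM = 128 this gives kNWarps = 12: one producer warp group (4 warps) for the TMA loads
// plus two consumer warp groups (8 warps) for the MMAs, matching
// NumMmaThreads = kNThreads - NumThreadsPerWarpGroup in the kernel traits.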
|
||||
|
||||
using Ktraits = moba_encoder_attn_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, kMaxN, true, InputType>;
|
||||
run_moba_decoder_attn<Ktraits>(params, stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DispatchMobaEncoderAttn(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& qk_gate_topk_idx,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& cu_seq_q_pack,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& out,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int batch_size,
|
||||
const int max_input_length) {
|
||||
|
||||
constexpr int kBlockM = 128;
|
||||
constexpr int kBlockN = 128;
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
|
||||
moba_encoder_attn_params params;
|
||||
memset(¶ms, 0, sizeof(moba_encoder_attn_params));
|
||||
|
||||
params.q_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>()));
|
||||
params.k_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>()));
|
||||
params.v_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>()));
|
||||
params.o_ptr = reinterpret_cast<cute_type*>(const_cast<T*>(out.data<T>()));
|
||||
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
|
||||
params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
|
||||
params.head_num = head_num;
|
||||
params.kv_head_num = kv_head_num;
|
||||
params.max_seq_q = max_seq_q;
|
||||
params.max_seq_k = max_seq_k;
|
||||
params.batch_size = batch_size;
|
||||
params.gqa_group_size = head_num / kv_head_num;
|
||||
constexpr float kLog2e = 1.4426950408889634074;
|
||||
params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
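// Scores are pre-scaled by log2(e) / sqrt(head_dim) so the softmax can use exp2f throughout;
// Softmax::finalize converts the accumulated log-sum back to natural-log units with M_LN2.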
|
||||
params.qk_gate_topk_idx = const_cast<int*>(qk_gate_topk_idx.data<int>());
|
||||
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
|
||||
params.cu_seq_q_pack = const_cast<int*>(cu_seq_q_pack.data<int>());
|
||||
|
||||
run_moba_encoder_attn_hdim128<kBlockM, kBlockN, kMaxN, cute_type>(params, out.stream());
|
||||
}
|
||||
|
||||
void MobaEncoderAttn(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& qk_gate_topk_idx,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& cu_seq_q_pack,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& out,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length) {
|
||||
|
||||
const int batch_size = seq_len_encoder.dims()[0];
|
||||
if (q_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
return
|
||||
DispatchMobaEncoderAttn<phi::dtype::float16>(
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
qk_gate_topk_idx,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
cu_seq_q_pack,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
out,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
batch_size,
|
||||
max_input_length);
|
||||
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return
|
||||
DispatchMobaEncoderAttn<phi::dtype::bfloat16>(
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
qk_gate_topk_idx,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
cu_seq_q_pack,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
out,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
batch_size,
|
||||
max_input_length);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(moba_encoder_attn)
|
||||
.Inputs({
|
||||
"q_input",
|
||||
"k_input",
|
||||
"v_input",
|
||||
"qk_gate_topk_idx",
|
||||
"cu_seq_q",
|
||||
"cu_seq_k",
|
||||
"cu_seq_q_pack",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"out"})
|
||||
.Attrs({
|
||||
"max_seq_q: int",
|
||||
"max_seq_k: int",
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"head_dim: int",
|
||||
"max_input_length: int"})
|
||||
.Outputs({"attn_out"})
|
||||
.SetInplaceMap({{"out", "attn_out"}})
|
||||
.SetKernelFn(PD_KERNEL(MobaEncoderAttn));
|
||||
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
|
||||
template <typename T, int kBlockSize, int kHeadDim>
|
||||
__global__ void write_encoder_cachekv_c16(
|
||||
const T * k_input,
|
||||
const T * v_input,
|
||||
const int * cu_seq_k,
|
||||
const int * seq_len_encoder,
|
||||
const int * seq_len_decoder,
|
||||
T * cache_k,
|
||||
T * cache_v,
|
||||
const int * block_tables,
|
||||
const int kv_head_num,
|
||||
const int max_blocks_per_seq) {
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(T);
|
||||
const int block_idx = blockIdx.x * kBlockSize;
|
||||
int bidh = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
const int tidx = threadIdx.x;
|
||||
const int row_idx = tidx / (kHeadDim / kPackSize);
|
||||
const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize;
|
||||
const int seq_len = seq_len_encoder[bidb];
|
||||
|
||||
if (seq_len == 0) return;
|
||||
|
||||
const int ramian_tokens = seq_len - block_idx;
|
||||
|
||||
const int32_t *block_table_now = block_tables + bidb * max_blocks_per_seq;
|
||||
const uint32_t physical_block_number = block_table_now[blockIdx.x + seq_len_decoder[bidb] / kBlockSize];
|
||||
|
||||
if (bidh < kv_head_num) {
|
||||
T * cache = cache_k + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(k_input + base_load_idx + i * kv_head_num * kHeadDim);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
bidh -= kv_head_num;
|
||||
const int base_load_idx = (block_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
T * cache = cache_v + physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < ramian_tokens) {
|
||||
*reinterpret_cast<float4*>(cache + i * kHeadDim) = *reinterpret_cast<const float4*>(v_input + base_load_idx + i * kv_head_num * kHeadDim);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
void MobaEncoderAttnWriteCacheKv(
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_seq_q,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
constexpr int kThreads = 128;
|
||||
constexpr int kHeadDim = 128;
|
||||
assert(kHeadDim == head_dim);
|
||||
constexpr int kBlockSize = 64;
|
||||
const int batch_size = block_tables.dims()[0];
|
||||
const int max_blocks_per_seq = block_tables.dims()[1];
|
||||
if (cache_quant_type_str == "none") {
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = (max_seq_q + kBlockSize - 1) / kBlockSize;
|
||||
grid_dims.y = kv_head_num * 2;
|
||||
grid_dims.z = batch_size;
|
||||
if (k_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
write_encoder_cachekv_c16<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, k_input.stream()>>>(
|
||||
const_cast<T*>(k_input.data<T>()),
|
||||
const_cast<T*>(v_input.data<T>()),
|
||||
cu_seq_k.data<int>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
const_cast<T*>(cache_k.data<T>()),
|
||||
const_cast<T*>(cache_v.data<T>()),
|
||||
block_tables.data<int>(),
|
||||
kv_head_num,
|
||||
max_blocks_per_seq);
|
||||
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
write_encoder_cachekv_c16<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, k_input.stream()>>>(
|
||||
const_cast<T*>(k_input.data<T>()),
|
||||
const_cast<T*>(v_input.data<T>()),
|
||||
cu_seq_k.data<int>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
const_cast<T*>(cache_k.data<T>()),
|
||||
const_cast<T*>(cache_v.data<T>()),
|
||||
block_tables.data<int>(),
|
||||
kv_head_num,
|
||||
max_blocks_per_seq);
|
||||
}
|
||||
} else {
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"Quantized cache not implemented for cache_quant_type = %s", cache_quant_type_str.c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_OP(moba_encoder_attn_write_cache_kv)
|
||||
.Inputs({
|
||||
"k_input",
|
||||
"v_input",
|
||||
"cu_seq_k",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"cache_k",
|
||||
"cache_v",
|
||||
"block_tables",
|
||||
paddle::Optional("cache_k_quant_scale"),
|
||||
paddle::Optional("cache_v_quant_scale"),
|
||||
paddle::Optional("cache_k_dequant_scale"),
|
||||
paddle::Optional("cache_v_dequant_scale"),
|
||||
paddle::Optional("cache_k_zero_points"),
|
||||
paddle::Optional("cache_v_zero_points")})
|
||||
.Attrs({
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"head_dim: int",
|
||||
"max_seq_q: int",
|
||||
"cache_quant_type_str: std::string"})
|
||||
.Outputs({"cache_k_out", "cache_v_out"})
|
||||
.SetInplaceMap({{"cache_k", "cache_k_out"},
|
||||
{"cache_v", "cache_v_out"}})
|
||||
.SetKernelFn(PD_KERNEL(MobaEncoderAttnWriteCacheKv));
|
||||
@@ -0,0 +1,341 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
|
||||
template <typename T, int knthreads, int moba_block_size, int kBlockM, int kBlockMaxN, int searchtimes>
|
||||
__global__ void qk_gate_sort_encoder_kernel(
|
||||
const T* qk_gate_weight,
|
||||
int * qk_gate_topk_idx,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int* cu_seq_q,
|
||||
const int* cu_seq_k,
|
||||
const int* cu_seq_q_pack,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int kGqaGroupSize,
|
||||
const int top_k_left,
|
||||
const int top_k_right) {
|
||||
|
||||
const int bidt = blockIdx.x * kBlockM;
|
||||
const int bidh = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kPackSize = kBlockMaxN / knthreads;
|
||||
|
||||
static_assert(kBlockMaxN % knthreads == 0);
|
||||
|
||||
const int seq_len_q = seq_len_encoder[bidb];
|
||||
|
||||
if (seq_len_q == 0 || bidt >= seq_len_q) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int seq_len_k = (bidt + kBlockM + seq_len_decoder[bidb]);
|
||||
|
||||
const int seq_len_moba = seq_len_k / moba_block_size;
|
||||
|
||||
using SrcType = Vec<T, kPackSize>;
|
||||
using SrcType_f = Vec<float, kPackSize>;
|
||||
using SrcType_i = Vec<int, kPackSize>;
|
||||
|
||||
SrcType src;
|
||||
SrcType_f src_f;
|
||||
|
||||
SrcType_i select_idx;
|
||||
|
||||
select_idx.set_zero();
|
||||
|
||||
const int store_idx = cu_seq_q_pack[bidb] / kBlockM * head_num * kBlockMaxN + bidh * kBlockMaxN + blockIdx.x * head_num * kBlockMaxN + tidx * kPackSize;
|
||||
|
||||
if (seq_len_k < use_moba_seq_limit) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
select_idx.data.elt[i] = 1;
|
||||
}
|
||||
select_idx.store_to(qk_gate_topk_idx + store_idx);
|
||||
return;
|
||||
}
|
||||
|
||||
const int load_offset = (cu_seq_q[bidb] + bidt) * head_num * kBlockMaxN + bidh * kBlockMaxN + tidx * kPackSize;
|
||||
const int data_len = seq_len_moba - tidx * kPackSize;
|
||||
|
||||
#pragma unroll
|
||||
for (int t = 0; t < kBlockM; t++) {
|
||||
if (bidt + t >= seq_len_q) {
|
||||
break;
|
||||
}
|
||||
src.load_from(qk_gate_weight + load_offset + t * head_num * kBlockMaxN);
|
||||
float min_global = FLT_MAX;
|
||||
float max_global = -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
if (i < data_len) {
|
||||
src_f.data.elt[i] = float(src.data.elt[i]);
|
||||
min_global = min(min_global, src_f.data.elt[i]);
|
||||
} else {
|
||||
src_f.data.elt[i] = -FLT_MAX;
|
||||
}
|
||||
max_global = max(max_global, src_f.data.elt[i]);
|
||||
}
|
||||
|
||||
max_global = BlockAllReduce<float, MaxOp<float>, knthreads>(max_global);
|
||||
min_global = BlockAllReduce<float, MinOp<float>, knthreads>(min_global);
|
||||
|
||||
float right_limit = max_global;
|
||||
float left_limit = min_global;
|
||||
|
||||
float mid_limit;
|
||||
int count;
|
||||
|
||||
if (right_limit == left_limit) {
|
||||
mid_limit = (left_limit + right_limit) * 0.5f;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < searchtimes; i++) {
|
||||
mid_limit = (left_limit + right_limit) * 0.5f;
|
||||
count = get_data_count<kPackSize, knthreads>(src_f.data.elt, mid_limit);
|
||||
if (count < top_k_left) {
|
||||
right_limit = mid_limit;
|
||||
} else if (count > top_k_right) {
|
||||
left_limit = mid_limit;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
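// Binary search over the score threshold (at most searchtimes iterations) until the number of
// blocks above mid_limit falls inside [top_k_left, top_k_right]; those blocks are the ones this
// query row keeps.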
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i++) {
|
||||
if (src_f.data.elt[i] >= mid_limit) {
|
||||
select_idx.data.elt[i] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tidx == 0) {
|
||||
select_idx.data.elt[0] = 1;
|
||||
}
|
||||
|
||||
__align__(16) __shared__ int qk_gate_mem[kBlockMaxN];
|
||||
__align__(16) __shared__ int qk_continue_idx_mem[kBlockMaxN];
|
||||
select_idx.store_to(qk_gate_mem + tidx * kPackSize);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tidx == 0) {
|
||||
int cur_idx = 0;
|
||||
int idx = -1;
|
||||
const int last_idx = seq_len_moba - 1;
|
||||
while (last_idx + idx >= 0 && qk_gate_mem[last_idx + idx] == 0) {
|
||||
idx--;
|
||||
}
|
||||
qk_continue_idx_mem[cur_idx] = -idx;
|
||||
cur_idx++;
|
||||
|
||||
for (int i = last_idx - 1; i >= 0; --i) {
|
||||
if (qk_gate_mem[i] == 1) {
|
||||
int idx = -1;
|
||||
while (i + idx >= 0 && qk_gate_mem[i + idx] == 0) {
|
||||
idx--;
|
||||
}
|
||||
qk_continue_idx_mem[cur_idx] = -idx;
|
||||
cur_idx++;
|
||||
}
|
||||
}
|
||||
qk_continue_idx_mem[cur_idx] = 10000000;
|
||||
}
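// The 0/1 selection bitmap is converted (back to front) into run-length deltas: entry i stores how
// far the attention kernels must step back from the current block to reach the next selected one,
// which is exactly what n_block -= qk_gate_topk_idx[idx] consumes.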
|
||||
|
||||
__syncthreads();
|
||||
|
||||
*reinterpret_cast<SrcType_i *>(qk_gate_topk_idx + store_idx) = reinterpret_cast<SrcType_i *>(qk_continue_idx_mem)[tidx];
|
||||
}
|
||||
|
||||
template <int kBlockM, int kMaxN, int moba_block_size, typename T>
|
||||
void qk_gate_sort_encoder(
|
||||
const T* qk_gate_weight,
|
||||
int * qk_gate_topk_idx,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int* cu_seq_q,
|
||||
const int* cu_seq_k,
|
||||
const int* cu_seq_q_pack,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int batch_size,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
cudaStream_t stream) {
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(T);
|
||||
|
||||
const int gqa_group_size = head_num / kv_head_num;
|
||||
const int knthreads = kMaxN / kPackSize;
|
||||
const int searchtimes = 6;
|
||||
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = (max_seq_q + kBlockM - 1) / kBlockM;
|
||||
grid_dims.y = head_num;
|
||||
grid_dims.z = batch_size;
|
||||
|
||||
constexpr auto kernel = qk_gate_sort_encoder_kernel<T, knthreads, moba_block_size, kBlockM, kMaxN, searchtimes>;
|
||||
|
||||
kernel<<<grid_dims, knthreads, 0, stream>>>(
|
||||
qk_gate_weight,
|
||||
qk_gate_topk_idx,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
cu_seq_q_pack,
|
||||
use_moba_seq_limit,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
gqa_group_size,
|
||||
top_k_left,
|
||||
top_k_right);
|
||||
}
|
||||
template <typename T>
|
||||
std::vector<paddle::Tensor> DispatchQkSortEncoder(
|
||||
const paddle::Tensor& qk_gate_weight,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& cu_seq_q_pack,
|
||||
const paddle::Tensor& q_pack_tokens,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit) {
|
||||
constexpr int kBlockM = 128;
|
||||
constexpr int kBlockN = 128;
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
const int batch_size = seq_len_encoder.dims()[0];
|
||||
|
||||
paddle::Tensor qk_gate_topk_idx = paddle::empty({q_pack_tokens.data<int>()[0] / kBlockM, head_num, kMaxN}, paddle::DataType::INT32, qk_gate_weight.place());
|
||||
|
||||
qk_gate_sort_encoder<kBlockM, kMaxN, kMobaBlockSize, cute_type>(
|
||||
reinterpret_cast<const cute_type *>(qk_gate_weight.data<T>()),
|
||||
qk_gate_topk_idx.data<int>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
cu_seq_q_pack.data<int>(),
|
||||
use_moba_seq_limit,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
qk_gate_weight.stream());
|
||||
|
||||
return {qk_gate_topk_idx};
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> QkSortEncoder(
|
||||
const paddle::Tensor& qk_gate_weight,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& cu_seq_q_pack,
|
||||
const paddle::Tensor& q_pack_tokens,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int top_k_left,
|
||||
const int top_k_right,
|
||||
const int use_moba_seq_limit) {
|
||||
if (qk_gate_weight.dtype() == paddle::DataType::FLOAT16) {
|
||||
return std::move(
|
||||
DispatchQkSortEncoder<phi::dtype::float16>(
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
cu_seq_q_pack,
|
||||
q_pack_tokens,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit
|
||||
)
|
||||
);
|
||||
} else if (qk_gate_weight.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return std::move(
|
||||
DispatchQkSortEncoder<phi::dtype::bfloat16>(
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
cu_seq_q_pack,
|
||||
q_pack_tokens,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
top_k_left,
|
||||
top_k_right,
|
||||
use_moba_seq_limit
|
||||
)
|
||||
);
|
||||
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"moba_qk_sort_encoder only supports float16 and bfloat16 for qk_gate_weight."));
}
}
|
||||
|
||||
PD_BUILD_OP(moba_qk_sort_encoder)
|
||||
.Inputs({
|
||||
"qk_gate_weight",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"cu_seq_q",
|
||||
"cu_seq_k",
|
||||
"cu_seq_q_pack",
|
||||
"q_pack_tokens"})
|
||||
.Attrs({
|
||||
"max_seq_q: int",
|
||||
"max_seq_k: int",
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"top_k_left: int",
|
||||
"top_k_right: int",
|
||||
"use_moba_seq_limit: int"})
|
||||
.Outputs({"qk_gate_topk_idx"})
|
||||
.SetKernelFn(PD_KERNEL(QkSortEncoder));
|
||||
194
custom_ops/gpu_ops/moba_attn/moba_encoder_attn/softmax.hpp
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "../moba_attn_utils.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
|
||||
}
|
||||
|
||||
__forceinline__ __device__ __half2 half_exp(__half2 x) {
|
||||
uint32_t tmp_out, tmp_in;
|
||||
tmp_in = reinterpret_cast<uint32_t&>(x);
|
||||
asm ("ex2.approx.f16x2 %0, %1;\n"
|
||||
: "=r"(tmp_out)
|
||||
: "r"(tmp_in));
|
||||
__half2 out = reinterpret_cast<__half2&>(tmp_out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const float max_scaled = max(mi) * scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
CUTLASS_DEVICE Softmax() {};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = row_max(mi);
|
||||
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
row_sum(mi) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template<bool Is_first, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT scores_scale;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = 1.0f / sum;
|
||||
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
|
||||
scores_scale(mi) = inv_sum;
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template<typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
};
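// Illustrative only: a minimal sketch of how Softmax<kNRows> is typically driven from a
// FlashAttention-style main loop. The tile loop and the acc_s / acc_o accumulators are
// assumptions for illustration, not part of this header.
//
//   Softmax<kNRows> softmax;
//   // first tile
//   auto scale0 = softmax.online_softmax</*Is_first=*/true>(acc_s, softmax_scale_log2);
//   // acc_o = P @ V
//   // every later tile
//   auto scale = softmax.max</*Is_first=*/false>(acc_s, softmax_scale_log2);  // update row_max, get rescale factor
//   softmax.rescale_o(acc_o, scale);                                          // rescale the running output
//   softmax.online_softmax</*Is_first=*/false>(acc_s, softmax_scale_log2);    // exp2 the scores, accumulate row_sum
//   // acc_o += P @ V
//   // after the last tile
//   softmax.rescale_o(acc_o, softmax.finalize(softmax_scale_log2));           // divide by row_sum; row_sum becomes the LSE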
|
||||
@@ -0,0 +1,288 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
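// Copies one kBlockSize-token block of K (or V) for one sequence from the paged cache
// into the contiguous k_input / v_input buffers using 16-byte vectorized loads/stores.
// Grid: (blocks_per_seq, 2 * kv_head_num, batch); blockIdx.y < kv_head_num selects a
// K head, otherwise the corresponding V head.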
template <typename T, int kBlockSize, int kHeadDim>
|
||||
__global__ void get_kv_from_cache_c16_kernel(
|
||||
T * k_input,
|
||||
T * v_input,
|
||||
const int * seq_len_encoder,
|
||||
const int * seq_len_decoder,
|
||||
const int * cu_seq_k,
|
||||
const T * cache_k,
|
||||
const T * cache_v,
|
||||
const int * block_tables,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int batch_size,
|
||||
const int max_input_length,
|
||||
const int max_blocks_per_seq) {
|
||||
|
||||
const int block_idx = blockIdx.x;
|
||||
int bidh = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
const int seq_len = seq_len_decoder[bidb] + seq_len_encoder[bidb];
|
||||
const int tidx = threadIdx.x;
|
||||
const int base_token_idx = block_idx * kBlockSize;
|
||||
|
||||
if (base_token_idx >= seq_len || seq_len_encoder[bidb] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(T);
|
||||
|
||||
const int row_idx = tidx / (kHeadDim / kPackSize);
|
||||
const int col_idx = tidx % (kHeadDim / kPackSize) * kPackSize;
|
||||
const int physical_block_number = block_tables[bidb * max_blocks_per_seq + block_idx];
|
||||
|
||||
|
||||
const int remain_tokens = seq_len - base_token_idx;
|
||||
|
||||
if (bidh < kv_head_num) {
|
||||
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(k_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_k + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
bidh -= kv_head_num;
|
||||
const int cache_offset = physical_block_number * kv_head_num * kBlockSize * kHeadDim + bidh * kBlockSize * kHeadDim + col_idx;
|
||||
const int base_store_idx = (base_token_idx + cu_seq_k[bidb]) * kv_head_num * kHeadDim + bidh * kHeadDim + col_idx;
|
||||
#pragma unroll
|
||||
for (int i = row_idx; i < kBlockSize; i += 128 / (kHeadDim / kPackSize)) {
|
||||
if (i < remain_tokens) {
|
||||
*reinterpret_cast<float4*>(v_input + base_store_idx + i * kv_head_num * kHeadDim) = *reinterpret_cast<const float4*>(cache_v + cache_offset + i * kHeadDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void get_kv_from_cache(
|
||||
T * k_input,
|
||||
T * v_input,
|
||||
const int * seq_len_encoder,
|
||||
const int * seq_len_decoder,
|
||||
const int * cu_seq_k,
|
||||
const void * cache_k,
|
||||
const void * cache_v,
|
||||
const int * block_tables,
|
||||
const T * cache_k_dequant_scale,
|
||||
const T * cache_v_dequant_scale,
|
||||
const T * cache_k_zero_points,
|
||||
const T * cache_v_zero_points,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_seq_k,
|
||||
const int batch_size,
|
||||
const int max_input_length,
|
||||
const int max_blocks_per_seq,
|
||||
const std::string &cache_quant_type_str,
|
||||
cudaStream_t stream) {
|
||||
|
||||
constexpr int kThreads = 128;
|
||||
constexpr int kHeadDim = 128;
|
||||
assert(kHeadDim == head_dim);
|
||||
constexpr int kBlockSize = 64;
|
||||
if (cache_quant_type_str == "none") {
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = (max_seq_k + kBlockSize - 1) / kBlockSize;
|
||||
grid_dims.y = kv_head_num * 2;
|
||||
grid_dims.z = batch_size;
|
||||
get_kv_from_cache_c16_kernel<T, kBlockSize, kHeadDim><<<grid_dims, kThreads, 0, stream>>>(
|
||||
k_input,
|
||||
v_input,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_k,
|
||||
reinterpret_cast<const T*>(cache_k),
|
||||
reinterpret_cast<const T*>(cache_v),
|
||||
block_tables,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
batch_size,
|
||||
max_input_length,
|
||||
max_blocks_per_seq);
|
||||
} else {
|
||||
PD_THROW("Only supported cache_quant_type_str in ['none'].");
|
||||
}
|
||||
}
|
||||
|
||||
void GetKVFromCache(
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cache_k,
|
||||
const paddle::Tensor& cache_v,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
if (k_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
get_kv_from_cache<cute_type>(
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
cache_k.data(),
|
||||
cache_v.data(),
|
||||
block_tables.data<int>(),
|
||||
cache_k_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_dequant_scale.get().data<T>())) : nullptr,
|
||||
cache_v_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_dequant_scale.get().data<T>())) : nullptr,
|
||||
cache_k_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_zero_points.get().data<T>())) : nullptr,
|
||||
cache_v_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_zero_points.get().data<T>())) : nullptr,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_k,
|
||||
seq_len_encoder.dims()[0],
|
||||
max_input_length,
|
||||
block_tables.dims()[1],
|
||||
cache_quant_type_str,
|
||||
k_input.stream());
|
||||
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
get_kv_from_cache<cute_type>(
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
cache_k.data(),
|
||||
cache_v.data(),
|
||||
block_tables.data<int>(),
|
||||
cache_k_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_dequant_scale.get().data<T>())) : nullptr,
|
||||
cache_v_dequant_scale ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_dequant_scale.get().data<T>())) : nullptr,
|
||||
cache_k_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_k_zero_points.get().data<T>())) : nullptr,
|
||||
cache_v_zero_points ? reinterpret_cast<cute_type*>(const_cast<T*>(cache_v_zero_points.get().data<T>())) : nullptr,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
max_seq_k,
|
||||
seq_len_encoder.dims()[0],
|
||||
max_input_length,
|
||||
block_tables.dims()[1],
|
||||
cache_quant_type_str,
|
||||
k_input.stream());
|
||||
}
|
||||
}
|
||||
|
||||
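// Single-thread kernel that builds the prefix sum of per-batch K lengths (cu_seqlens_k),
// the pack_size-aligned prefix sum of Q lengths (cu_seq_q_pack), and the total padded
// Q token count. Example: with pack_size = 128, q_len = 300 contributes
// ceil(300 / 128) * 128 = 384 to cu_seq_q_pack.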
__global__ void get_cur_cu_seq_len_k_kernel(
|
||||
const int* __restrict__ seq_lens_encoder,
|
||||
const int* __restrict__ seq_lens_decoder,
|
||||
const int* __restrict__ seq_lens_this_time,
|
||||
int* __restrict__ cu_seqlens_k,
|
||||
int* __restrict__ cu_seq_q_pack,
|
||||
int* __restrict__ q_pack_tokens,
|
||||
const int pack_size,
|
||||
const int bsz) {
|
||||
|
||||
int total_tokens = 0;
|
||||
cu_seqlens_k[0] = 0;
|
||||
cu_seq_q_pack[0] = 0;
|
||||
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int cache_len = seq_lens_decoder[bid];
|
||||
const int q_len = seq_lens_encoder[bid];
|
||||
if (q_len <= 0) {
|
||||
cache_len = 0;
|
||||
}
|
||||
total_tokens += (cache_len + q_len);
|
||||
cu_seqlens_k[bid + 1] = total_tokens;
|
||||
cu_seq_q_pack[bid + 1] = cu_seq_q_pack[bid] + (q_len + pack_size - 1) / pack_size * pack_size;
|
||||
}
|
||||
q_pack_tokens[0] = cu_seq_q_pack[bsz];
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> GetCurCuSeqLenk(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const int pack_size) {
|
||||
auto stream = seq_lens_decoder.stream();
|
||||
auto place = seq_lens_decoder.place();
|
||||
int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
paddle::Tensor cu_seq_q_pack = paddle::empty({bsz + 1}, paddle::DataType::INT32, place);
|
||||
paddle::Tensor cu_seqlens_k = paddle::empty({bsz + 1}, paddle::DataType::INT32, place);
|
||||
paddle::Tensor q_pack_tokens = paddle::empty({1}, paddle::DataType::INT32, place);
|
||||
|
||||
get_cur_cu_seq_len_k_kernel<<<1, 1, 0, stream>>>(
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
cu_seq_q_pack.data<int>(),
|
||||
q_pack_tokens.data<int>(),
|
||||
pack_size,
|
||||
bsz
|
||||
);
|
||||
|
||||
auto q_pack_tokens_cpu = q_pack_tokens.copy_to(paddle::CPUPlace(), true);
|
||||
return {cu_seq_q_pack, cu_seqlens_k, q_pack_tokens_cpu};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_kv_from_cache)
|
||||
.Inputs({
|
||||
"k_input",
|
||||
"v_input",
|
||||
"cu_seq_k",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"cache_k",
|
||||
"cache_v",
|
||||
"block_tables",
|
||||
paddle::Optional("cache_k_dequant_scale"),
|
||||
paddle::Optional("cache_v_dequant_scale"),
|
||||
paddle::Optional("cache_k_zero_points"),
|
||||
paddle::Optional("cache_v_zero_points")})
|
||||
.Attrs({
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"head_dim: int",
|
||||
"max_input_length: int",
|
||||
"max_seq_k: int",
|
||||
"cache_quant_type_str: std::string"})
|
||||
.Outputs({"k_input_out", "v_input_out"})
|
||||
.SetInplaceMap({{"k_input", "k_input_out"},
|
||||
{"v_input", "v_input_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetKVFromCache));
|
||||
|
||||
PD_BUILD_OP(get_cur_cu_seq_len_k)
|
||||
.Inputs({
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time"})
|
||||
.Attrs({
|
||||
"pack_size: int"})
|
||||
.Outputs({"cu_seq_q_pack", "cu_seqlens_k", "q_pack_tokens"})
|
||||
.SetKernelFn(PD_KERNEL(GetCurCuSeqLenk));
|
||||
custom_ops/gpu_ops/moba_attn/moba_process/moba_mlp_einsum.cu (new file, 221 lines)
@@ -0,0 +1,221 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
|
||||
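// For each (moba block, kv head, batch) tile, reduces the block's K rows against the
// per-head gate weight: each thread accumulates a strided partial sum over tokens,
// partials are combined with __shfl_down_sync and shared memory across the 128-thread
// block, and the result is written as moba_block_size / 128 rows of the
// [batch, kMaxN, kv_head_num, kHeadDim] output.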
template <typename T, int moba_block_size, int kHeadDim, int kMaxN>
|
||||
__global__ void moba_mlp_einsum_kernel(
|
||||
const T * src_data,
|
||||
const T * weight_data,
|
||||
const int * seq_lens_encoder,
|
||||
const int * seq_lens_decoder,
|
||||
const int * cu_seq_k,
|
||||
T * dst_data,
|
||||
const int head_num) {
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(T);
|
||||
const int block_idx = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int bidb = blockIdx.z;
|
||||
const int tidx = threadIdx.x;
|
||||
const int lane_id = tidx % 32;
|
||||
const int warp_id = tidx / 32;
|
||||
|
||||
__align__(16) __shared__ T local_sum_mem[128 / 32 * kHeadDim];
|
||||
|
||||
const int seq_len_encoder = seq_lens_encoder[bidb];
|
||||
const int seq_len_decoder = seq_len_encoder + seq_lens_decoder[bidb];
|
||||
|
||||
const int seq_len_this_block = seq_len_decoder - block_idx * moba_block_size;
|
||||
|
||||
if (seq_len_encoder == 0 || seq_len_this_block <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
using SrcType = Vec<T, kPackSize>;
|
||||
|
||||
constexpr int tidx_per_row = kHeadDim / kPackSize;
|
||||
|
||||
const int row_idx = tidx / tidx_per_row;
|
||||
const int col_idx = tidx % tidx_per_row * kPackSize;
|
||||
|
||||
const int src_base_idx = cu_seq_k[bidb] * head_num * kHeadDim + block_idx * moba_block_size * head_num * kHeadDim + bidh * kHeadDim + row_idx * head_num * kHeadDim + col_idx;
|
||||
const int weight_base_idx = bidh * kHeadDim * moba_block_size + row_idx * kHeadDim + col_idx;
|
||||
|
||||
constexpr int step = 128 / tidx_per_row;
|
||||
|
||||
SrcType sums, src, weight;
|
||||
|
||||
sums.set_zero();
|
||||
|
||||
for (int i = 0; i < moba_block_size; i += step) {
|
||||
if (i >= seq_len_this_block) {
|
||||
break;
|
||||
}
|
||||
src.load_from(src_data + src_base_idx + i * head_num * kHeadDim);
|
||||
weight.load_from(weight_data + weight_base_idx + i * kHeadDim);
|
||||
sums.fma(src, weight);
|
||||
}
|
||||
|
||||
SrcType neighbor;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i+=2) {
|
||||
*reinterpret_cast<int32_t*>(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast<int32_t*>(sums.data.elt + i), 16);
|
||||
}
|
||||
|
||||
sums.add(neighbor);
|
||||
|
||||
if (lane_id < 16) {
|
||||
sums.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
using pack_half = std::conditional_t<std::is_same<T, phi::dtype::float16>::value, __half2, nv_bfloat162>;
|
||||
pack_half * local_sum_mem_half = reinterpret_cast<pack_half*>(local_sum_mem);
|
||||
|
||||
if (tidx < kHeadDim / 2) {
|
||||
pack_half local_sum_half = local_sum_mem_half[tidx];
|
||||
#pragma unroll
|
||||
for (int i = 1; i < 4; i++) {
|
||||
local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)];
|
||||
}
|
||||
local_sum_mem_half[tidx] = local_sum_half;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int store_row_id = tidx / (kHeadDim / kPackSize);
|
||||
const int store_col_id = tidx % (kHeadDim / kPackSize) * kPackSize;
|
||||
|
||||
sums.load_from(local_sum_mem + store_col_id);
|
||||
|
||||
const int base_store_idx = bidb * kMaxN * head_num * kHeadDim + (block_idx * (moba_block_size / 128) + store_row_id) * head_num * kHeadDim + bidh * kHeadDim + store_col_id;
|
||||
|
||||
if (store_row_id < moba_block_size / 128) {
|
||||
sums.store_to(dst_data + base_store_idx);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int kHeadDim, int kMaxN>
|
||||
void moba_mlp_einsum(
|
||||
const T * src_data,
|
||||
const T * weight_data,
|
||||
const int * seq_lens_encoder,
|
||||
const int * seq_lens_decoder,
|
||||
const int * cu_seq_k,
|
||||
T * dst_data,
|
||||
const int moba_block_size,
|
||||
const int max_seq_len,
|
||||
const int head_num,
|
||||
const int batch_size,
|
||||
cudaStream_t stream) {
|
||||
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = (max_seq_len + moba_block_size - 1) / moba_block_size;
|
||||
grid_dims.y = head_num;
|
||||
grid_dims.z = batch_size;
|
||||
|
||||
if (moba_block_size == 1024) {
|
||||
moba_mlp_einsum_kernel<T, 1024, kHeadDim, kMaxN><<<grid_dims, 128, 0, stream>>>(
|
||||
src_data,
|
||||
weight_data,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
cu_seq_k,
|
||||
dst_data,
|
||||
head_num);
|
||||
} else if (moba_block_size == 128) {
|
||||
moba_mlp_einsum_kernel<T, 128, kHeadDim, kMaxN><<<grid_dims, 128, 0, stream>>>(
|
||||
src_data,
|
||||
weight_data,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
cu_seq_k,
|
||||
dst_data,
|
||||
head_num);
|
||||
} else {
|
||||
PADDLE_THROW(phi::errors::Unimplemented(
|
||||
"MobaMlpEinsum not implemented for moba_block_size = %d", moba_block_size));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MobaMlpEinsum(
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& attn_gate_weight,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const int max_seq_len,
|
||||
const int kv_head_num) {
|
||||
|
||||
const int kHeadDim = 128;
|
||||
const int kMaxN = 1024;
|
||||
const int moba_block_size = attn_gate_weight.dims()[1];
|
||||
const int batch_size = seq_lens_encoder.dims()[0];
|
||||
paddle::Tensor k_gate_weight = paddle::zeros({batch_size, kMaxN, kv_head_num, kHeadDim}, k_input.dtype(), k_input.place());
|
||||
|
||||
if (k_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
moba_mlp_einsum<T, kHeadDim, kMaxN>(
|
||||
const_cast<T*>(k_input.data<T>()),
|
||||
const_cast<T*>(attn_gate_weight.data<T>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int*>(cu_seq_k.data<int>()),
|
||||
k_gate_weight.data<T>(),
|
||||
moba_block_size,
|
||||
max_seq_len,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
k_input.stream()
|
||||
);
|
||||
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
moba_mlp_einsum<T, kHeadDim, kMaxN>(
|
||||
const_cast<T*>(k_input.data<T>()),
|
||||
const_cast<T*>(attn_gate_weight.data<T>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int*>(cu_seq_k.data<int>()),
|
||||
k_gate_weight.data<T>(),
|
||||
moba_block_size,
|
||||
max_seq_len,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
k_input.stream()
|
||||
);
|
||||
}
|
||||
return {k_gate_weight};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(moba_mlp_einsum)
|
||||
.Inputs({
|
||||
"k_input",
|
||||
"attn_gate_weight",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"cu_seq_k"})
|
||||
.Attrs({
|
||||
"max_seq_len: int",
|
||||
"kv_head_num: int"})
|
||||
.Outputs({"k_gate"})
|
||||
.SetKernelFn(PD_KERNEL(MobaMlpEinsum));
|
||||
custom_ops/gpu_ops/moba_attn/moba_process/moba_qk_gemm.cu (new file, 465 lines)
@@ -0,0 +1,465 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
|
||||
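// CuTe/CUTLASS tiled GEMM that multiplies Q by the per-block K gate means to produce
// the MoBA gate logits (qk_gate_weight). Prefill path (is_split_kv == false): each CTA
// owns a kBlockM slice of Q tokens and loops over kBlockN-wide slices of gate blocks.
// Split-KV path: each CTA owns a kBlockN slice of gate blocks and the kGQA_groupsize
// query heads that share one KV head.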
template <typename input_type, int kBlockM, int kBlockN, int kMobaBlockSize, int kMaxN, int kHeadDim, bool is_split_kv>
|
||||
__global__ void qk_gemm_kernel(
|
||||
const input_type *q_input,
|
||||
const input_type *k_gate_mean,
|
||||
input_type *qk_gate_weight,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int kGQA_groupsize) {
|
||||
|
||||
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, input_type,
|
||||
decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(
|
||||
cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, input_type, decltype(cute::get<1>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomQK = decltype(
|
||||
cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, input_type,
|
||||
decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<1>(TileShape_MNK{}))>());
|
||||
|
||||
using SmemLayoutQK = decltype(tile_to_shape(SmemLayoutAtomQK{}, select<0, 1>(TileShape_MNK{})));
|
||||
|
||||
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<input_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
|
||||
using ValLayoutMNK = std::conditional_t<
|
||||
is_split_kv,
|
||||
Layout<Shape<_1,_4,_1>>,
|
||||
Layout<Shape<_4,_1,_1>>
|
||||
>;
|
||||
|
||||
using PermutationMNK = std::conditional_t<
|
||||
is_split_kv,
|
||||
Tile<_16,_64,_16>,
|
||||
Tile<_64,_16,_16>
|
||||
>;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
MMA_Atom_Arch,
|
||||
ValLayoutMNK,
|
||||
PermutationMNK>;
|
||||
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, input_type>;
|
||||
using SmemCopyAtomQK = Copy_Atom<cute::SM90_U32x4_STSM_N, input_type>;
|
||||
|
||||
constexpr int kNThreads = 128;
|
||||
constexpr int kThreadPerValue = 16 / sizeof(input_type);
|
||||
constexpr int kThreadsPerRow = kHeadDim / kThreadPerValue;
|
||||
constexpr int kThreadsPerRowQK = kBlockN / kThreadPerValue;
|
||||
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape <Int<kNThreads / kThreadsPerRow>, Int<kThreadsPerRow>>,
|
||||
Stride<Int<kThreadsPerRow>, _1>>;
|
||||
|
||||
using GmemTiledCopy = decltype(
|
||||
make_tiled_copy(Copy_Atom<
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, input_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
|
||||
using GmemLayoutAtomQK = Layout<
|
||||
Shape <Int<kNThreads / kThreadsPerRowQK>, Int<kThreadsPerRowQK>>,
|
||||
Stride<Int<kThreadsPerRowQK>, _1>>;
|
||||
|
||||
using GmemTiledCopyQK = decltype(
|
||||
make_tiled_copy(Copy_Atom<
|
||||
UniversalCopy<cutlass::uint128_t>, input_type>{},
|
||||
GmemLayoutAtomQK{},
|
||||
Layout<Shape<_1, Int<kThreadPerValue>>>{}));
|
||||
|
||||
int mn_block = blockIdx.x;
|
||||
const int bidb = blockIdx.y;
|
||||
const int bidh = blockIdx.z;
|
||||
const int bidh_k = bidh / kGQA_groupsize;
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const int seq_len_q = seq_len_encoder[bidb];
|
||||
const int seq_len_k = seq_len_decoder[bidb];
|
||||
const int seq_len_qk = seq_len_q + seq_len_k;
|
||||
|
||||
int q_head_stride;
|
||||
const int k_head_stride = kv_head_num * kHeadDim;
|
||||
int qk_head_stride;
|
||||
int offset_q;
|
||||
int offset_k;
|
||||
int offset_qk;
|
||||
int remain_q_seq;
|
||||
|
||||
if constexpr (is_split_kv) {
|
||||
if (seq_len_k < use_moba_seq_limit || seq_len_k == 0) {
|
||||
return;
|
||||
}
|
||||
mn_block *= kBlockN;
|
||||
q_head_stride = kHeadDim;
|
||||
qk_head_stride = kMaxN;
|
||||
if (mn_block >= (seq_len_k + kMobaBlockSize - 1) / kMobaBlockSize) {
|
||||
return;
|
||||
}
|
||||
offset_q = cu_seq_q[bidb] * head_num * kHeadDim + bidh * kGQA_groupsize * kHeadDim;
|
||||
offset_k = (bidb * kMaxN + mn_block) * k_head_stride + bidh * kHeadDim;
|
||||
offset_qk = bidb * head_num * kMaxN + bidh * kGQA_groupsize * kMaxN + mn_block;
|
||||
remain_q_seq = kGQA_groupsize;
|
||||
} else {
|
||||
if (seq_len_q == 0 || seq_len_qk < use_moba_seq_limit) {
|
||||
return;
|
||||
}
|
||||
q_head_stride = head_num * kHeadDim;
|
||||
qk_head_stride = head_num * kMaxN;
|
||||
mn_block *= kBlockM;
|
||||
if (mn_block >= seq_len_q) {
|
||||
return;
|
||||
}
|
||||
offset_q = (cu_seq_q[bidb] + mn_block) * q_head_stride + bidh * kHeadDim;
|
||||
offset_k = bidb * kMaxN * k_head_stride + bidh_k * kHeadDim;
|
||||
offset_qk = (cu_seq_q[bidb] + mn_block) * qk_head_stride + bidh * kMaxN;
|
||||
remain_q_seq = seq_len_q - mn_block;
|
||||
}
|
||||
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(q_input + offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(q_head_stride, _1{}));
|
||||
Tensor gK = make_tensor(make_gmem_ptr(k_gate_mean + offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(k_head_stride, _1{}));
|
||||
Tensor gQK = make_tensor(make_gmem_ptr(qk_gate_weight + offset_qk),
|
||||
Shape<Int<kBlockM>, Int<kBlockN>>{},
|
||||
make_stride(qk_head_stride, _1{}));
|
||||
|
||||
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<input_type *>(smem_)), SmemLayoutK{});
|
||||
Tensor sQ = make_tensor(sK.data() + size(sK), SmemLayoutQ{});
|
||||
Tensor sQK = make_tensor(sK.data() + size(sK), SmemLayoutQK{});
|
||||
|
||||
auto gmem_tiled_copy = GmemTiledCopy{};
|
||||
auto gmem_tiled_copy_qk = GmemTiledCopyQK{};
|
||||
auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx);
|
||||
auto gmem_thr_copy_qk = gmem_tiled_copy_qk.get_thread_slice(tidx);
|
||||
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy.partition_D(sQ);
|
||||
|
||||
Tensor tKgK = gmem_thr_copy.partition_S(gK);
|
||||
Tensor tKsK = gmem_thr_copy.partition_D(sK);
|
||||
|
||||
Tensor tQKgQK = gmem_thr_copy_qk.partition_S(gQK);
|
||||
Tensor tQKsQK = gmem_thr_copy_qk.partition_D(sQK);
|
||||
|
||||
|
||||
Tensor cQ = make_identity_tensor(make_shape(kBlockM, kHeadDim));
|
||||
Tensor tQcQ = gmem_thr_copy.partition_S(cQ);
|
||||
|
||||
Tensor cK = make_identity_tensor(make_shape(kBlockN, kHeadDim));
|
||||
Tensor tKcK = gmem_thr_copy.partition_S(cK);
|
||||
|
||||
Tensor cQK = make_identity_tensor(make_shape(kBlockM, kBlockN));
|
||||
Tensor tQKcQK = gmem_thr_copy.partition_S(cQK);
|
||||
|
||||
if (remain_q_seq >= kBlockM) {
|
||||
copy(gmem_tiled_copy, tQgQ, tQsQ, tQcQ);
|
||||
} else {
|
||||
copy<false>(gmem_tiled_copy, tQgQ, tQsQ, tQcQ, remain_q_seq);
|
||||
}
|
||||
copy(gmem_tiled_copy, tKgK, tKsK, tKcK);
|
||||
|
||||
cute::cp_async_fence();
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK);
|
||||
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
|
||||
|
||||
auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||
|
||||
auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
||||
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||
|
||||
auto smem_tiled_copy_QK = make_tiled_copy_C(SmemCopyAtomQK{}, tiled_mma);
|
||||
auto smem_thr_copy_QK = smem_tiled_copy_QK.get_thread_slice(tidx);
|
||||
Tensor tsQK = smem_thr_copy_QK.partition_D(sQK);
|
||||
|
||||
const int n_blocks = is_split_kv ? 1 : cute::ceil_div(cute::ceil_div(seq_len_qk, kMobaBlockSize), kBlockN);
|
||||
|
||||
#pragma unroll
|
||||
for (int n_block = 0; n_block < n_blocks; ++n_block) {
|
||||
clear(acc_s);
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block == 0) {
|
||||
gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
|
||||
} else {
|
||||
gemm<true>(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K, smem_tiled_copy_Q, smem_tiled_copy_K);
|
||||
}
|
||||
if constexpr (!is_split_kv) {
|
||||
if (n_block < n_blocks - 1) {
|
||||
__syncthreads();
|
||||
tKgK.data() = tKgK.data() + kBlockN * k_head_stride;
|
||||
copy(gmem_tiled_copy, tKgK, tKsK, tKcK);
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor rS = convert_type<input_type>(acc_s);
|
||||
Tensor trQK = smem_thr_copy_QK.retile_S(rS);
|
||||
cute::copy(smem_tiled_copy_QK, trQK, tsQK);
|
||||
|
||||
__syncthreads();
|
||||
if (remain_q_seq >= kBlockM) {
|
||||
copy(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK);
|
||||
} else {
|
||||
copy<false>(gmem_tiled_copy_qk, tQKsQK, tQKgQK, tQKcQK, remain_q_seq);
|
||||
}
|
||||
if constexpr (!is_split_kv) {
|
||||
__syncthreads();
|
||||
tQKgQK.data() = tQKgQK.data() + kBlockN;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_type, int kBlockM, int kBlockN, int kMobaBlockSize, int kMaxN, bool is_split_kv>
|
||||
void qk_gemm(
|
||||
const input_type *q_input,
|
||||
const input_type *k_gate_mean,
|
||||
input_type *qk_gate_weight,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int use_moba_seq_limit,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int bsz,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int gqa_group_size = head_num / kv_head_num;
|
||||
|
||||
dim3 grid_dims;
|
||||
const int num_m_block = (max_seq_q + kBlockM - 1) / kBlockM;
|
||||
const int num_n_block = ((max_seq_k + kMobaBlockSize - 1) / kMobaBlockSize + kBlockN - 1) / kBlockN;
|
||||
|
||||
if (is_split_kv) {
|
||||
grid_dims.x = num_n_block;
|
||||
grid_dims.z = kv_head_num;
|
||||
} else {
|
||||
grid_dims.x = num_m_block;
|
||||
grid_dims.z = head_num;
|
||||
}
|
||||
grid_dims.y = bsz;
|
||||
|
||||
constexpr int kHeadDim = 128;
|
||||
constexpr int smemq = kBlockM * kHeadDim * sizeof(input_type);
|
||||
constexpr int smemk = kBlockN * kHeadDim * sizeof(input_type);
|
||||
constexpr int smemqk = kBlockM * kBlockN * sizeof(input_type);
|
||||
const int smem_size = smemk + max(smemq, smemqk);
|
||||
|
||||
auto kernel = &qk_gemm_kernel<input_type, kBlockM, kBlockN, kMobaBlockSize, kMaxN, kHeadDim, is_split_kv>;
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
|
||||
kernel<<<grid_dims, 128, smem_size, stream>>>(
|
||||
q_input,
|
||||
k_gate_mean,
|
||||
qk_gate_weight,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
use_moba_seq_limit,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
gqa_group_size);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::vector<paddle::Tensor> DispatchMobaQKGemm(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const bool is_split_kv,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
const int batch_size = seq_len_encoder.dims()[0];
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
if (is_split_kv) {
|
||||
paddle::Tensor qk_gate_weight = paddle::empty({batch_size, head_num, kMaxN}, q_input.dtype(), q_input.place());
|
||||
qk_gemm<cute_type, 16, kMobaBlockSize, kMobaBlockSize, kMaxN, true>(
|
||||
reinterpret_cast<const cute_type*>(q_input.data<T>()),
|
||||
reinterpret_cast<const cute_type*>(k_block_means.data<T>()),
|
||||
reinterpret_cast<cute_type*>(qk_gate_weight.data<T>()),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
use_moba_seq_limit,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
q_input.stream()
|
||||
);
|
||||
return {qk_gate_weight};
|
||||
} else {
|
||||
constexpr int kBlockM = 128;
|
||||
constexpr int kBlockN = 128;
|
||||
const int token_num = q_input.dims()[0];
|
||||
paddle::Tensor qk_gate_weight = paddle::empty({token_num, head_num, kMaxN}, q_input.dtype(), q_input.place());
|
||||
qk_gemm<cute_type, kBlockM, kBlockN, kMobaBlockSize, kMaxN, false>(
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(q_input.data<T>())),
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
|
||||
reinterpret_cast<cute_type *>(qk_gate_weight.data<T>()),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
use_moba_seq_limit,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
batch_size,
|
||||
q_input.stream());
|
||||
return {qk_gate_weight};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MobaQKGemm(
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const bool is_split_kv,
|
||||
const int use_moba_seq_limit) {
|
||||
|
||||
if (q_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
return std::move(
|
||||
DispatchMobaQKGemm<phi::dtype::float16>(
|
||||
q_input,
|
||||
k_block_means,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
is_split_kv,
|
||||
use_moba_seq_limit
|
||||
)
|
||||
);
|
||||
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
return std::move(
|
||||
DispatchMobaQKGemm<phi::dtype::bfloat16>(
|
||||
q_input,
|
||||
k_block_means,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
is_split_kv,
|
||||
use_moba_seq_limit
|
||||
)
|
||||
);
|
||||
} else {
PD_THROW("Only supported q_input dtype in ['float16', 'bfloat16'].");
}
}
|
||||
|
||||
PD_BUILD_OP(moba_qk_gemm)
|
||||
.Inputs({
|
||||
"q_input",
|
||||
"k_block_means",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"cu_seq_q",
|
||||
"cu_seq_k"})
|
||||
.Attrs({
|
||||
"max_seq_q: int",
|
||||
"max_seq_k: int",
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"is_split_kv: bool",
|
||||
"use_moba_seq_limit: int"})
|
||||
.Outputs({"qk_gate_weight"})
|
||||
.SetKernelFn(PD_KERNEL(MobaQKGemm));
|
||||
custom_ops/gpu_ops/moba_attn/moba_process/split_qkv_and_rope.cu (new file, 370 lines)
@@ -0,0 +1,370 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "moba_attn/moba_attn_utils.hpp"
|
||||
#include "moba_attn/moba_attn.h"
|
||||
|
||||
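// Splits the fused QKV activations into the separate q_input / k_input / v_input
// buffers, adds the optional qkv_bias, and applies rotary position embedding to Q and
// K. When need_k_mean is set, it additionally accumulates a per-moba-block mean of the
// rotated K vectors into k_gate_mean for the gating GEMM.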
template <typename input_type, int moba_block_size, int kBlockM, int kMaxN, int tokens_per_block, bool need_k_mean>
|
||||
__global__ void fused_block_mean_and_rope_kernel(
|
||||
const input_type *qkv_input,
|
||||
const input_type *qkv_bias,
|
||||
input_type *k_gate_mean,
|
||||
input_type *q_input,
|
||||
input_type *k_input,
|
||||
input_type *v_input,
|
||||
const float *rope_sin_cos,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int max_input_length) {
|
||||
|
||||
constexpr int kPackSize = 16 / sizeof(input_type);
|
||||
constexpr int kHeadDim = 128;
|
||||
|
||||
using src_type = Vec<input_type, kPackSize>;
|
||||
|
||||
using rope_type = Vec<float, kPackSize / 2>;
|
||||
using pack_half = std::conditional_t<std::is_same<input_type, cutlass::half_t>::value, __half2, nv_bfloat162>;
|
||||
|
||||
__align__(16) __shared__ input_type local_sum_mem[128 / 32 * kHeadDim];
|
||||
|
||||
const int bidb = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int bidt_q = blockIdx.z * tokens_per_block;
|
||||
const int bidt_v = blockIdx.z * tokens_per_block;
|
||||
const int bidt_k = need_k_mean ? blockIdx.z * moba_block_size : blockIdx.z * tokens_per_block;
|
||||
const int tidx = threadIdx.x;
|
||||
const int lane_id = tidx % 32;
|
||||
const int warp_id = tidx / 32;
|
||||
const int seq_len = seq_len_encoder[bidb];
|
||||
const int seq_len_start = seq_len_decoder[bidb];
|
||||
|
||||
if (seq_len == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int all_head_num = head_num + 2 * kv_head_num;
|
||||
const int hidden = all_head_num * kHeadDim;
|
||||
|
||||
const int row_idx = tidx / (kHeadDim / kPackSize);
|
||||
const int col_idx = tidx % (kHeadDim / kPackSize);
|
||||
|
||||
const int bias_idx = bidh * kHeadDim + col_idx * kPackSize;
|
||||
|
||||
src_type src, src_bias;
|
||||
rope_type sin, cos;
|
||||
|
||||
const bool need_add_bias = qkv_bias != nullptr;
|
||||
|
||||
if (need_add_bias) {
|
||||
src_bias.load_from(qkv_bias + bias_idx);
|
||||
}
|
||||
|
||||
if (bidh < head_num) {
|
||||
const int cur_token = bidt_q + row_idx;
|
||||
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
|
||||
if (cur_token < seq_len) {
|
||||
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
|
||||
|
||||
if (need_add_bias) {
|
||||
src.add(src_bias);
|
||||
}
|
||||
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
|
||||
|
||||
src.store_to(q_input + (cu_seq_q[bidb] + cur_token) * head_num * kHeadDim + bias_idx);
|
||||
}
|
||||
} else if (bidh < head_num + kv_head_num) {
|
||||
if constexpr (!need_k_mean) {
|
||||
const int cur_token = bidt_k + row_idx;
|
||||
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
|
||||
if (cur_token < seq_len) {
|
||||
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
|
||||
|
||||
if (need_add_bias) {
|
||||
src.add(src_bias);
|
||||
}
|
||||
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
|
||||
|
||||
src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * head_num * kHeadDim + bias_idx- head_num * kHeadDim);
|
||||
}
|
||||
} else {
|
||||
if (bidt_k >= seq_len) {
|
||||
return;
|
||||
}
|
||||
|
||||
src_type local_sum;
|
||||
local_sum.set_zero();
|
||||
|
||||
const input_type* qkv = qkv_input + cu_seq_q[bidb] * hidden + bias_idx;
|
||||
|
||||
for (int i = 0; i < moba_block_size; i += tokens_per_block) {
|
||||
const int cur_token = bidt_k + i + row_idx;
|
||||
if (cur_token < seq_len) {
|
||||
src.load_from(qkv + cur_token * hidden);
|
||||
|
||||
if (need_add_bias) {
|
||||
src.add(src_bias);
|
||||
}
|
||||
const float * cos_rope = rope_sin_cos + (cur_token + seq_len_start) * (kHeadDim / 2) + col_idx * (kPackSize / 2);
|
||||
const float * sin_rope = cos_rope + max_input_length * (kHeadDim / 2);
|
||||
sin.load_from(sin_rope);
|
||||
cos.load_from(cos_rope);
|
||||
|
||||
apply_rotary_embedding<input_type, kPackSize>(src, cos, sin);
|
||||
|
||||
src.store_to(k_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - head_num * kHeadDim);
|
||||
|
||||
local_sum.add(src);
|
||||
}
|
||||
}
|
||||
|
||||
src_type neighbor;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kPackSize; i+=2) {
|
||||
*reinterpret_cast<int32_t*>(neighbor.data.elt + i) = __shfl_down_sync(0xffffffff, *reinterpret_cast<int32_t*>(local_sum.data.elt + i), 16);
|
||||
}
|
||||
|
||||
local_sum.add(neighbor);
|
||||
|
||||
if (lane_id < 16) {
|
||||
local_sum.store_to(local_sum_mem + warp_id * kHeadDim + lane_id * kPackSize);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
pack_half * local_sum_mem_half = reinterpret_cast<pack_half*>(local_sum_mem);
|
||||
|
||||
pack_half local_sum_half = local_sum_mem_half[tidx];
|
||||
|
||||
|
||||
if (tidx < kHeadDim / 2) {
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 1; i < 4; i++) {
|
||||
local_sum_half += local_sum_mem_half[tidx + i * (kHeadDim / 2)];
|
||||
}
|
||||
|
||||
float inv_tokens_sum = fdividef(1.0f, min(seq_len - bidt_k, moba_block_size));
|
||||
|
||||
local_sum_half *= float_2_half2<input_type>(inv_tokens_sum);
|
||||
|
||||
const int store_mean_idx = ((bidb * kMaxN + blockIdx.z + seq_len_start / moba_block_size) * kv_head_num * kHeadDim + (bidh - head_num) * kHeadDim) / 2 + tidx;
|
||||
|
||||
reinterpret_cast<pack_half*>(k_gate_mean)[store_mean_idx] = local_sum_half;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int cur_token = bidt_v + row_idx;
|
||||
|
||||
if (cur_token < seq_len) {
|
||||
src.load_from(qkv_input + cu_seq_q[bidb] * hidden + bias_idx + cur_token * hidden);
|
||||
if (need_add_bias) {
|
||||
src.add(src_bias);
|
||||
}
|
||||
|
||||
src.store_to(v_input + (cu_seq_k[bidb] + cur_token) * kv_head_num * kHeadDim + bias_idx - (head_num + kv_head_num) * kHeadDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_type, int moba_block_size, int kBlockM, int kMaxN>
|
||||
void fused_block_mean_and_rope(
|
||||
const input_type *qkv_input,
|
||||
const input_type *qkv_bias,
|
||||
input_type *k_gate_mean,
|
||||
input_type *q_input,
|
||||
input_type *k_input,
|
||||
input_type *v_input,
|
||||
const float *rope_sin_cos,
|
||||
const int *seq_len_encoder,
|
||||
const int *seq_len_decoder,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int bsz,
|
||||
const int max_input_length,
|
||||
cudaStream_t stream) {
|
||||
|
||||
static_assert(moba_block_size >= 64, "moba_block_size must be at least 64");
|
||||
constexpr int kPackSize = 16 / sizeof(input_type);
|
||||
constexpr int kHeadDim = 128;
|
||||
constexpr int kThreads = 128;
|
||||
constexpr int tokens_per_block = kThreads / (kHeadDim / kPackSize);
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = bsz;
|
||||
grid_dims.y = head_num + 2 * kv_head_num;
|
||||
grid_dims.z = (max_seq_q + tokens_per_block - 1) / tokens_per_block;
|
||||
|
||||
if (k_gate_mean != nullptr) {
|
||||
fused_block_mean_and_rope_kernel<input_type, moba_block_size, kBlockM, kMaxN, tokens_per_block, true>
|
||||
<<<grid_dims, kThreads, 0, stream>>>(
|
||||
qkv_input,
|
||||
qkv_bias,
|
||||
k_gate_mean,
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
rope_sin_cos,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
max_input_length);
|
||||
} else {
|
||||
fused_block_mean_and_rope_kernel<input_type, moba_block_size, kBlockM, kMaxN, tokens_per_block, false>
|
||||
<<<grid_dims, kThreads, 0, stream>>>(
|
||||
qkv_input,
|
||||
qkv_bias,
|
||||
k_gate_mean,
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
rope_sin_cos,
|
||||
seq_len_encoder,
|
||||
seq_len_decoder,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
max_input_length);
|
||||
}
|
||||
}
|
||||
|
||||
void FusedBlockMeanAndRope(
|
||||
const paddle::Tensor& qkv_out,
|
||||
const paddle::Tensor& k_block_means,
|
||||
const paddle::Tensor& q_input,
|
||||
const paddle::Tensor& k_input,
|
||||
const paddle::Tensor& v_input,
|
||||
const paddle::Tensor& rotary_embs,
|
||||
const paddle::Tensor& seq_len_encoder,
|
||||
const paddle::Tensor& seq_len_decoder,
|
||||
const paddle::Tensor& cu_seq_q,
|
||||
const paddle::Tensor& cu_seq_k,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const int head_num,
|
||||
const int kv_head_num,
|
||||
const int head_dim,
|
||||
const int max_input_length,
|
||||
const int max_seq_q,
|
||||
const int max_seq_k,
|
||||
const std::string &cache_quant_type_str) {
|
||||
|
||||
constexpr int kBlockM = 128;
|
||||
constexpr int kBlockN = 128;
|
||||
constexpr int kMobaBlockSize = 128;
|
||||
constexpr int kMaxN = 1024;
|
||||
|
||||
if (k_input.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
fused_block_mean_and_rope<cute_type, kMobaBlockSize, kBlockM, kMaxN>(
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(qkv_out.data<T>())),
|
||||
qkv_bias ? reinterpret_cast<cute_type *>(const_cast<T*>(qkv_bias.get().data<T>())) : nullptr,
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
|
||||
rotary_embs.data<float>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
seq_len_encoder.dims()[0],
|
||||
max_input_length,
|
||||
qkv_out.stream());
|
||||
} else if (k_input.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
using cute_type = typename cuteType<T>::type;
|
||||
fused_block_mean_and_rope<cute_type, kMobaBlockSize, kBlockM, kMaxN>(
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(qkv_out.data<T>())),
|
||||
qkv_bias ? reinterpret_cast<cute_type *>(const_cast<T*>(qkv_bias.get().data<T>())) : nullptr,
|
||||
reinterpret_cast<cute_type *>(const_cast<T*>(k_block_means.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(q_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(k_input.data<T>())),
|
||||
reinterpret_cast<cute_type*>(const_cast<T*>(v_input.data<T>())),
|
||||
rotary_embs.data<float>(),
|
||||
seq_len_encoder.data<int>(),
|
||||
seq_len_decoder.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
cu_seq_k.data<int>(),
|
||||
max_seq_q,
|
||||
max_seq_k,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
seq_len_encoder.dims()[0],
|
||||
max_input_length,
|
||||
qkv_out.stream());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
PD_BUILD_OP(fused_block_mean_and_rope)
|
||||
.Inputs({
|
||||
"qkv_out",
|
||||
"k_block_means",
|
||||
"q_input",
|
||||
"k_input",
|
||||
"v_input",
|
||||
"rotary_embs",
|
||||
"seq_len_encoder",
|
||||
"seq_len_decoder",
|
||||
"cu_seq_q",
|
||||
"cu_seq_k",
|
||||
paddle::Optional("qkv_bias")})
|
||||
.Attrs({
|
||||
"head_num: int",
|
||||
"kv_head_num: int",
|
||||
"head_dim: int",
|
||||
"max_input_length: int",
|
||||
"max_seq_q: int",
|
||||
"max_seq_k: int",
|
||||
"cache_quant_type_str: std::string"})
|
||||
.Outputs({"q_input_out", "k_input_out", "v_input_out", "k_block_means_out"})
|
||||
.SetInplaceMap({{"q_input", "q_input_out"},
|
||||
{"k_input", "k_input_out"},
|
||||
{"v_input", "v_input_out"},
|
||||
{"k_block_means", "k_block_means_out"}})
|
||||
.SetKernelFn(PD_KERNEL(FusedBlockMeanAndRope));
|
||||
@@ -25,6 +25,66 @@
|
||||
|
||||
#include "helper.h"
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \
|
||||
switch (num_experts_per_rank) { \
|
||||
case 2: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 2; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 6: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 6; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 8: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 8; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 9: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 9; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 16: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 32: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 32; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 48: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 48; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 64: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 128: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 160: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 160; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported num_experts_per_rank: " << num_experts_per_rank; \
|
||||
throw std::invalid_argument(err_msg.str()); \
|
||||
} \
|
||||
}
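// Example (mirrors the call sites further down in this file): dispatch a kernel
// template on the runtime expert count.
//
//   DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
//       permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, NUM_EXPERTS_PER_RANK>
//           <<<gridx, 512, 0, stream>>>(/* kernel arguments */);
//   )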
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template<typename T>
|
||||
@@ -269,7 +329,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
}


template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int RoundType = 1>
template <typename T, typename OutT, int NUM_EXPERTS_PER_RANK = 8, int Kthread = 512, int RoundType = 1>
__global__ void permute_x_kernel(const T *src_x,
const int64_t *topk_idx,
const float *topk_weights,
@@ -285,9 +345,9 @@ __global__ void permute_x_kernel(const T *src_x,
int *dst_indices,
int *cumsum_idx_gpu,
int64_t *token_nums_per_expert_cumsum,
int64_t *expert_idx_per_token,
int64_t *expert_idx_per_token, // [num_rows, moe_topk]
float max_bound = 127.0,
float min_bound = -127.0) { // [num_rows, moe_topk]
float min_bound = -127.0) {
const int src_token_idx = blockIdx.x;
const int tid = threadIdx.x;
constexpr int vec_size = sizeof(int4) / sizeof(T);
@@ -330,10 +390,17 @@ __global__ void permute_x_kernel(const T *src_x,
if (up_gate_proj_in_scale) {
for (int i = 0; i < vec_size; i++) {
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
if (RoundType == 0) {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
if constexpr (std::is_same<OutT, int8_t>::value) {
// w4aint8
if (RoundType == 0) {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
} else {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(round(quant_value), min_bound, max_bound));
}
} else {
res_vec[i] = static_cast<OutT>(round(quant_value));
// w4afp8
float value = ClipFunc<float>(quant_value, min_bound, max_bound);
res_vec[i] = static_cast<OutT>(value);
}
}
} else {
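The hunk above interleaves the old and new bodies of the quantization loop; read as the new code, the branch logic is roughly the following condensed sketch (not the actual helper; it assumes OutT is int8_t on the w4a8/w4aint8 path and float8_e4m3fn on the w4afp8 path, and reuses ClipFunc from this file):

// Sketch: int8 targets round (rint or round, per RoundType) and then clip;
// the fp8 target clips to [min_bound, max_bound] and lets the cast do the rounding.
template <typename OutT, int RoundType>
__device__ inline OutT quantize_one(float quant_value, float min_bound, float max_bound) {
  if constexpr (std::is_same<OutT, int8_t>::value) {
    const float rounded = (RoundType == 0) ? rintf(quant_value) : roundf(quant_value);
    return static_cast<OutT>(ClipFunc<float>(rounded, min_bound, max_bound));
  } else {
    return static_cast<OutT>(ClipFunc<float>(quant_value, min_bound, max_bound));
  }
}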
@@ -373,6 +440,10 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
|
||||
typedef typename traits_fp8::DataType DataType_fp8;
|
||||
typedef typename traits_fp8::data_t data_t_fp8;
|
||||
|
||||
auto stream = input.stream();
|
||||
auto place = input.place();
|
||||
const int gridx = min(132 * 8, num_rows);
|
||||
@@ -420,6 +491,50 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
-127.0
|
||||
);
|
||||
}
|
||||
} else if (moe_quant_type == "w4afp8") {
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t_fp8>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
448.0f,
|
||||
-448.0f
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t_fp8>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
448.0f,
|
||||
-448.0f
|
||||
);
|
||||
}
|
||||
} else {
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
|
||||
@@ -493,7 +608,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
|
||||
auto permute_input = GetEmptyTensor(
|
||||
{token_nums_this_rank, hidden_size},
|
||||
moe_quant_type == "w4a8" ? paddle::DataType::INT8 : input_type,
|
||||
moe_quant_type == "w4a8" ? paddle::DataType::INT8 : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN : input_type,
|
||||
place);
|
||||
auto num_experts_per_rank_tensor = GetEmptyTensor(
|
||||
{num_experts_per_rank},
|
||||
@@ -743,8 +858,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
|
||||
auto place = input.place();
|
||||
// const int gridx = min(132 * 8, num_rows);
|
||||
const int gridx = 132 * 8;
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 8><<<gridx, 512, 0, stream>>>(
|
||||
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
@@ -765,102 +880,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 9) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 9><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 16><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 64) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 64><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 128) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 128><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else {
|
||||
PD_THROW("Not dispatching this num_experts_per_rank(", num_experts_per_rank, ") for EPMoeDispatchFP8Kernel");
|
||||
}
|
||||
);)
|
||||
|
||||
}
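The removed per-count branches above are what DISPATCH_NUM_EXPERTS_PER_RANK now generates; the trailing ');)' closes the kernel argument list, the launch statement, and then the macro invocation. Supporting another expert count only needs one more case in the macro, e.g. (hypothetical value, following the existing pattern):

    case 96: { \
      constexpr size_t NUM_EXPERTS_PER_RANK = 96; \
      __VA_ARGS__ \
      break; \
    } \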
|
||||
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ struct nv_type_traits<int8_t> {
constexpr int kLogN = 7; \
__VA_ARGS__ \
} else { \
PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupport!", logN)); \
PADDLE_THROW(phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \
}

#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...) \
@@ -108,7 +108,7 @@ struct nv_type_traits<int8_t> {
constexpr int VEC_SIZE = 1; \
__VA_ARGS__ \
} else { \
PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupport!", vec_size)); \
PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", vec_size)); \
}

#define DISPATCH_logN(logN, kLogN, ...) \
@@ -605,26 +605,6 @@ void moe_fast_hardamard_kernel(const T *x,
|
||||
exchange_smem_pre<kNChunks, kChunksPerSmemSize, VecSize, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
|
||||
}
|
||||
if constexpr (kNChunks > 1) {
|
||||
// T x_vals_transposed[VecSize][kNChunks] = {init_value};
|
||||
// #pragma unroll
|
||||
// for (int c = 0; c < kNChunks; ++c) {
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < VecSize; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
|
||||
// }
|
||||
// if constexpr (kNChunks == 28) {
|
||||
// hadamard_mult_thread_chunk_28<VecSize>(x_vals_transposed);
|
||||
// } else if constexpr (kNChunks == 36) {
|
||||
// hadamard_mult_thread_chunk_36<VecSize>(x_vals_transposed);
|
||||
// } else {
|
||||
// constexpr int kLogNChunks = cilog2(kNChunks);
|
||||
// static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
|
||||
// hadamard_mult_thread<kLogNChunks, VecSize>(x_vals_transposed);
|
||||
// }
|
||||
// #pragma unroll
|
||||
// for (int c = 0; c < kNChunks; ++c) {
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < VecSize; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
|
||||
// }
|
||||
if constexpr (kNChunks == 28) {
|
||||
hadamard_mult_thread_28_transpose<T, VecSize>(x_vals);
|
||||
} else if constexpr (kNChunks == 36) {
|
||||
|
||||
@@ -72,6 +72,287 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
|
||||
return u;
|
||||
}
|
||||
|
||||
struct uint8 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
};
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<>
|
||||
struct BytesToType<32> {
|
||||
using Type = uint8;
|
||||
static_assert(sizeof(Type) == 32);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
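BytesToType is used by the kernels below to pick a POD type whose size equals one output vector, so the whole vector is written with a single aligned store; in sketch form (with OutT = float8_e4m3fn and VecSize = 8, BytesToType<8>::Type is uint64_t):

using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
*reinterpret_cast<vec_t*>(&out[offset]) = *reinterpret_cast<const vec_t*>(&output_vec);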
|
||||
|
||||
template <template <typename> class ReductionOp, typename T, int block_size>
|
||||
__inline__ __device__ T BlockAllReduce(T val) {
|
||||
typedef cub::BlockReduce<T, block_size> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ T result_broadcast;
|
||||
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
|
||||
if (threadIdx.x == 0) {
|
||||
result_broadcast = result;
|
||||
}
|
||||
__syncthreads();
|
||||
return result_broadcast;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
|
||||
const float scale,
|
||||
const float max_bound,
|
||||
const float min_bound) {
|
||||
float quant_value = max_bound * scale * static_cast<float>(input);
|
||||
return static_cast<OutType>(ClipFunc<float>(quant_value, min_bound, max_bound));
|
||||
}
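A condensed sketch of how QuantHelperFunc and BlockAllReduce combine in the quantization kernels below (illustrative helper, not part of the diff; BlockSize must match blockDim.x and every thread of the block must reach the reduction):

template <typename T, typename OutT, int VecSize, int BlockSize>
__device__ float quantize_and_row_sum(const T (&in)[VecSize], OutT (&out)[VecSize],
                                      float scale, float max_bound, float min_bound) {
  // Quantize this thread's elements and accumulate a partial sum of the quantized values.
  float partial = 0.0f;
  for (int i = 0; i < VecSize; ++i) {
    out[i] = QuantHelperFunc<T, OutT>(in[i], scale, max_bound, min_bound);
    partial += static_cast<float>(out[i]);
  }
  // Block-wide sum, broadcast to every thread.
  return BlockAllReduce<SumOp, float, BlockSize>(partial);
}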
|
||||
|
||||
template <typename T, typename OutT, int VecSize, int Kthread>
|
||||
__global__ void masked_quantize_moe_input_kernel(const T* permuted_inputs,
|
||||
const int64_t* expert_idx_per_token,
|
||||
const float* quant_scales,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
float* permuted_input_row_sum,
|
||||
const int64_t* recv_expert_count,
|
||||
const int num_max_tokens_per_expert,
|
||||
OutT* out) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadOutT = AlignedVector<OutT, VecSize>;
|
||||
LoadT input_vec;
|
||||
LoadOutT output_vec;
|
||||
float scale_factor = -7.0f / 512.0f;
|
||||
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
|
||||
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
|
||||
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
|
||||
const auto expert_id = token_idx / num_max_tokens_per_expert;
|
||||
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
|
||||
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
|
||||
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
|
||||
token_idx += num_iters_to_next_expert * gridDim.x;
|
||||
continue;
|
||||
}
|
||||
int64_t expert_idx = expert_idx_per_token[token_idx];
|
||||
float quant_scale = quant_scales[expert_idx];
|
||||
float thread_row_sum = 0.0f;
|
||||
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||
int64_t offset = token_idx * dim + idx * VecSize;
|
||||
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
|
||||
thread_row_sum += static_cast<float>(output_vec[i]);
|
||||
}
|
||||
*(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
|
||||
}
|
||||
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
|
||||
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
|
||||
}
|
||||
}
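The early-continue above skips the padded, unoccupied slots of an expert in the low-latency (masked) layout. A worked example of the index arithmetic, with illustrative numbers:

// num_max_tokens_per_expert = 128, gridDim.x = 100, recv_expert_count[0] = 40.
// A block landing on token_idx = 50 (an empty slot of expert 0) computes:
//   next_expert_start_idx    = (0 + 1) * 128 = 128
//   num_iters_to_next_expert = (128 - 50 - 1) / 100 = 0
// so token_idx is left for the grid-stride increment, which moves it to 150,
// a slot of expert 1, instead of stepping through every padded slot one by one.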
|
||||
|
||||
template <typename T, typename OutT, int VecSize, int Kthread>
|
||||
__global__ void quantize_moe_input_kernel(const T* permuted_inputs,
|
||||
const int64_t* expert_idx_per_token,
|
||||
const float* quant_scales,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
float* permuted_input_row_sum,
|
||||
const int64_t* recv_expert_count,
|
||||
const int num_max_tokens_per_expert,
|
||||
OutT* out) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadOutT = AlignedVector<OutT, VecSize>;
|
||||
LoadT input_vec;
|
||||
LoadOutT output_vec;
|
||||
using vec_t = typename BytesToType<sizeof(OutT) * VecSize>::Type;
|
||||
float scale_factor = -7.0f / 512.0f;
|
||||
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
|
||||
int64_t expert_idx = expert_idx_per_token[token_idx];
|
||||
float quant_scale = quant_scales[expert_idx];
|
||||
float thread_row_sum = 0.0f;
|
||||
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||
int64_t offset = token_idx * dim + idx * VecSize;
|
||||
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
output_vec[i] = QuantHelperFunc<T, OutT>(input_vec[i], quant_scale, quant_max_bound, quant_min_bound);
|
||||
thread_row_sum += static_cast<float>(output_vec[i]);
|
||||
}
|
||||
*(reinterpret_cast<vec_t*>(&out[offset])) = *(reinterpret_cast<const vec_t*>(&output_vec));
|
||||
}
|
||||
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
|
||||
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void quantize_moe_input(
|
||||
const T* permuted_inputs,
|
||||
const int64_t* expert_idx_per_token,
|
||||
const float* quant_scales,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
float* permuted_input_row_sum,
|
||||
const int64_t* recv_expert_count,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
OutT* out,
|
||||
cudaStream_t stream) {
|
||||
constexpr int VecSize = 16 / sizeof(T);
|
||||
constexpr int threads_per_block = 128;
|
||||
const int dev_id = 0;
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
assert(dim % VecSize == 0);
|
||||
auto kernel = used_in_ep_low_latency ? masked_quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block> : quantize_moe_input_kernel<T, OutT, VecSize, threads_per_block>;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, threads_per_block, 0);
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
dim3 grid;
|
||||
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
|
||||
kernel<<<grid, threads_per_block, 0, stream>>>(
|
||||
permuted_inputs,
|
||||
expert_idx_per_token,
|
||||
quant_scales,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
token_num,
|
||||
dim,
|
||||
permuted_input_row_sum,
|
||||
recv_expert_count,
|
||||
num_max_tokens_per_expert,
|
||||
out);
|
||||
}
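A hypothetical call site for the launcher above (pointer names and the bf16-to-fp8 combination are assumptions for illustration; the ±448 bounds match the float8 E4M3 range used elsewhere in this diff):

quantize_moe_input<phi::dtype::bfloat16, phi::dtype::float8_e4m3fn>(
    permuted_x_ptr, expert_idx_per_token_ptr, per_expert_scales_ptr,
    /*quant_max_bound=*/448.0f, /*quant_min_bound=*/-448.0f,
    token_num, hidden_size, permuted_input_row_sum_ptr,
    recv_expert_count_ptr, num_max_tokens_per_expert,
    /*used_in_ep_low_latency=*/true, out_fp8_ptr, stream);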
|
||||
|
||||
template <typename T, int VecSize, int Kthread>
|
||||
__global__ void masked_compute_row_sum_kernel(
|
||||
const T* permuted_inputs,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
float* permuted_input_row_sum,
|
||||
const int64_t* recv_expert_count,
|
||||
const int num_max_tokens_per_expert) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT input_vec;
|
||||
float scale_factor = -7.0f / 512.0f;
|
||||
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
|
||||
const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert;
|
||||
const auto expert_id = token_idx / num_max_tokens_per_expert;
|
||||
if (token_idx_in_expert >= recv_expert_count[expert_id]) {
|
||||
auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert;
|
||||
auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x;
|
||||
token_idx += num_iters_to_next_expert * gridDim.x;
|
||||
continue;
|
||||
}
|
||||
float thread_row_sum = 0.0f;
|
||||
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||
int64_t offset = token_idx * dim + idx * VecSize;
|
||||
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
thread_row_sum += static_cast<float>(input_vec[i]);
|
||||
}
|
||||
}
|
||||
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
|
||||
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize, int Kthread>
|
||||
__global__ void compute_row_sum_kernel(
|
||||
const T* permuted_inputs,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
float* permuted_input_row_sum,
|
||||
const int64_t* recv_expert_count,
|
||||
const int num_max_tokens_per_expert) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT input_vec;
|
||||
float scale_factor = -7.0f / 512.0f;
|
||||
for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) {
|
||||
float thread_row_sum = 0.0f;
|
||||
for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||
int64_t offset = token_idx * dim + idx * VecSize;
|
||||
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
thread_row_sum += static_cast<float>(input_vec[i]);
|
||||
}
|
||||
}
|
||||
float block_row_sum = BlockAllReduce<SumOp, float, Kthread>(thread_row_sum);
|
||||
permuted_input_row_sum[token_idx] = block_row_sum * scale_factor;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void compute_row_sum(
|
||||
const T* permuted_inputs,
|
||||
const int64_t token_num,
|
||||
const int64_t dim,
|
||||
float* permuted_input_row_sum,
|
||||
const int64_t* recv_expert_count,
|
||||
const int num_max_tokens_per_expert,
|
||||
bool used_in_ep_low_latency,
|
||||
cudaStream_t stream) {
|
||||
constexpr int VecSize = 16 / sizeof(T);
|
||||
constexpr int threads_per_block = 128;
|
||||
const int dev_id = 0;
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
assert(dim % VecSize == 0);
|
||||
auto kernel = used_in_ep_low_latency ? masked_compute_row_sum_kernel<T, VecSize, threads_per_block> : compute_row_sum_kernel<T, VecSize, threads_per_block>;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, threads_per_block, 0);
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
dim3 grid;
|
||||
grid.x = min(static_cast<int64_t>(num_blocks_per_wave), token_num);
|
||||
kernel<<<grid, threads_per_block, 0, stream>>>(
|
||||
permuted_inputs,
|
||||
token_num,
|
||||
dim,
|
||||
permuted_input_row_sum,
|
||||
recv_expert_count,
|
||||
num_max_tokens_per_expert);
|
||||
}
|
||||
|
||||
// ====================== Softmax things ===============================
|
||||
// We have our own implementation of softmax here so we can support transposing
|
||||
// the output in the softmax kernel when we extend this module to support
|
||||
@@ -150,8 +431,61 @@ __launch_bounds__(TPB) __global__
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
|
||||
T* output,
|
||||
const int64_t num_cols,
|
||||
const int64_t num_rows) {
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
|
||||
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (globalIdx >= num_rows) {
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val =
|
||||
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||
output[idx] = T(val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
__launch_bounds__(TPB) __global__ void group_moe_top_k(const T* inputs_after_softmax,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
@@ -208,60 +542,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
|
||||
T* output,
|
||||
const int64_t num_cols,
|
||||
const int64_t num_rows) {
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
|
||||
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (globalIdx >= num_rows) {
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val =
|
||||
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||
output[idx] = T(val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
const T* bias,
|
||||
T* output,
|
||||
@@ -284,6 +565,13 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
|
||||
const bool should_process_row = true;
|
||||
const int thread_read_offset = block_row * num_experts;
|
||||
T weight_sum = static_cast<T>(0);
|
||||
T* row_outputs = nullptr;
|
||||
|
||||
if constexpr (NormWeights){
|
||||
extern __shared__ char smem[];
|
||||
row_outputs = reinterpret_cast<T*>(smem);
|
||||
}
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
@@ -296,7 +584,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert) {
|
||||
inp_kvp = thread_kvp;
|
||||
@@ -310,15 +598,31 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
|
||||
if constexpr (NormWeights){
|
||||
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
row_outputs[k_idx] = row_out;
|
||||
weight_sum += row_out;
|
||||
}
|
||||
else{
|
||||
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if constexpr (NormWeights){
|
||||
if (threadIdx.x < WARP_SIZE) {
|
||||
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
|
||||
}
|
||||
if (threadIdx.x < k) {
|
||||
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
|
||||
}
|
||||
}
|
||||
}
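When NormWeights is enabled, the kernel stages the k selected gate values in dynamic shared memory (row_outputs) and renormalizes them so they sum to 1, so a launch has to reserve k * sizeof(T) bytes per block. A sketch of such a launch (argument order follows moe_top_k_normed shown further down; grid_dim would come from Get1DBlocksAnd2DGridsMoe as in the surrounding launch code):

moe_top_k<float, /*TPB=*/256, /*NormWeights=*/true>
    <<<grid_dim, 256, k * sizeof(float), stream>>>(
        softmax_out, bias, topk_weights, topk_indices, source_rows,
        num_experts, k, num_rows);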
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
@@ -356,165 +660,6 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
T val = T(threadDataExp * normalizing_factor);
|
||||
|
||||
// top_k
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
|
||||
|
||||
cub_kvp thread_kvp;
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
cub_kvp inp_kvp;
|
||||
int expert = threadIdx.x;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = bias ? val + bias[expert] : val;
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert) {
|
||||
inp_kvp = thread_kvp;
|
||||
}
|
||||
}
|
||||
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||
}
|
||||
|
||||
const cub_kvp result_kvp =
|
||||
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int cur_idx = k * globalIdx + k_idx;
|
||||
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
indices[cur_idx] = result_kvp.key;
|
||||
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
cub_kvp thread_kvp;
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (block_row >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const bool should_process_row = true;
|
||||
const int thread_read_offset = block_row * num_experts;
|
||||
T weight_sum = static_cast<T>(0);
|
||||
|
||||
extern __shared__ char smem[];
|
||||
|
||||
T* row_outputs = reinterpret_cast<T*>(smem);
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert) {
|
||||
inp_kvp = thread_kvp;
|
||||
}
|
||||
}
|
||||
|
||||
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||
}
|
||||
|
||||
const cub_kvp result_kvp =
|
||||
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
|
||||
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
|
||||
row_outputs[k_idx] = row_out;
|
||||
weight_sum += row_out;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (threadIdx.x < WARP_SIZE) {
|
||||
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
|
||||
}
|
||||
|
||||
if (threadIdx.x < k) {
|
||||
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int TPB, typename IdxT = int>
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
// softmax
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
|
||||
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (globalIdx >= num_rows) {
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_experts;
|
||||
const int64_t idx = thread_row_offset+threadIdx.x;
|
||||
|
||||
cub::Sum sum;
|
||||
|
||||
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float threadDataSub = threadData - float_max;
|
||||
float threadDataExp = exp(threadDataSub);
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
@@ -532,8 +677,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
T weight_sum = static_cast<T>(0);
|
||||
extern __shared__ char smem[];
|
||||
T* row_outputs = reinterpret_cast<T*>(smem);
|
||||
T* row_outputs = nullptr;
|
||||
if constexpr (NormWeights){
|
||||
extern __shared__ char smem[];
|
||||
row_outputs = reinterpret_cast<T*>(smem);
|
||||
}
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
@@ -560,22 +708,28 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
|
||||
if (threadIdx.x == 0) {
|
||||
const int cur_idx = k * globalIdx + k_idx;
|
||||
|
||||
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
row_outputs[k_idx] = row_out;
|
||||
weight_sum += row_out;
|
||||
|
||||
indices[cur_idx] = result_kvp.key;
|
||||
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
|
||||
|
||||
if constexpr (NormWeights) {
|
||||
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
row_outputs[k_idx] = row_out;
|
||||
weight_sum += row_out;
|
||||
}
|
||||
else {
|
||||
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if constexpr (NormWeights) {
|
||||
if (threadIdx.x < WARP_SIZE) {
|
||||
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
|
||||
}
|
||||
|
||||
if (threadIdx.x < WARP_SIZE) {
|
||||
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
|
||||
}
|
||||
|
||||
if (threadIdx.x < k) {
|
||||
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
|
||||
if (threadIdx.x < k) {
|
||||
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -697,9 +851,11 @@ template <typename T,
|
||||
int NUM_EXPERTS,
|
||||
int WARPS_PER_CTA,
|
||||
int BYTES_PER_LDG,
|
||||
bool Norm_Weights = false,
|
||||
typename IdxT = int>
|
||||
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
|
||||
void topk_gating_softmax(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
const int64_t num_rows,
|
||||
IdxT* indices,
|
||||
@@ -755,6 +911,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
|
||||
// We compute row offset for each thread sub-group
|
||||
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
|
||||
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||
const int thread_row_in_cta = thread_row - cta_base_row;
|
||||
|
||||
// Threads with indices out of bounds should early exit here.
|
||||
if (thread_row >= num_rows) return;
|
||||
@@ -770,6 +927,9 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
T weight_sum = static_cast<T>(0);
|
||||
extern __shared__ T row_output[];
|
||||
|
||||
// Determine the pointer type to use to read in the data depending on the
|
||||
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
|
||||
// up to 16.
|
||||
@@ -838,7 +998,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
|
||||
row_chunk[ii] = bias ? row_chunk[ii] * reciprocal_row_sum + bias[first_elt_read_by_thread + ii] : row_chunk[ii] * reciprocal_row_sum;
|
||||
}
|
||||
|
||||
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find
|
||||
@@ -887,12 +1047,20 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
T final_val = bias ? T(max_val) - bias[expert] : T(max_val);
|
||||
if (thread_group_idx == 0) {
|
||||
// The lead thread from each sub-group will write out the final results to
|
||||
// global memory. (This will be a single) thread per row of the
|
||||
// input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = T(max_val);
|
||||
if constexpr (Norm_Weights) {
|
||||
const int idx_in_cta = k * thread_row_in_cta + k_idx;
|
||||
row_output[idx_in_cta] = final_val;
|
||||
weight_sum += final_val;
|
||||
}
|
||||
else {
|
||||
output[idx] = final_val;
|
||||
}
|
||||
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
}
|
||||
@@ -915,6 +1083,16 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (Norm_Weights) {
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
if (thread_group_idx == 0) {
|
||||
const int idx = k * thread_row + k_idx;
|
||||
const int idx_in_cta = k * thread_row_in_cta + k_idx;
|
||||
output[idx] = row_output[idx_in_cta] / weight_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
@@ -934,8 +1112,9 @@ struct TopkConstants {
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, int EXPERTS, int WARPS_PER_TB, typename IdxT = int>
|
||||
template <typename T, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights = false, typename IdxT = int>
|
||||
void topk_gating_softmax_launcher_helper(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
IdxT* indices,
|
||||
int* source_row,
|
||||
@@ -953,9 +1132,10 @@ void topk_gating_softmax_launcher_helper(const T* input,
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
|
||||
<<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, output, num_rows, indices, source_row, k);
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
|
||||
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, Norm_Weights>
|
||||
<<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>(
|
||||
input, bias, output, num_rows, indices, source_row, k);
|
||||
}
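The launcher now passes ROWS_PER_CTA * k * sizeof(T) bytes of dynamic shared memory because each CTA handles ROWS_PER_CTA rows and, with Norm_Weights, stages k gate values per row. With illustrative numbers:

// e.g. WARPS_PER_TB = 4 and ROWS_PER_WARP = 2 give ROWS_PER_CTA = 8;
// with k = 8 and T = float that is 8 * 8 * 4 = 256 bytes of shared memory per block.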
|
||||
|
||||
template <typename T, typename IdxT = int>
|
||||
@@ -986,7 +1166,7 @@ static void run(const T* input,
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
case N: { \
|
||||
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
|
||||
input, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||
input, gating_correction_bias, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||
break; \
|
||||
}
|
||||
int64_t tem_num_experts = num_experts;
|
||||
@@ -1015,7 +1195,7 @@ static void run(const T* input,
|
||||
group_experts,
|
||||
softmax_num_rows);
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
group_moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
output,
|
||||
indices,
|
||||
@@ -1116,6 +1296,18 @@ __global__ void initialize_moe_routing_kernel(
|
||||
dest_vec[j] = static_cast<int8_t>(round(quant_value));
|
||||
}
|
||||
Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
|
||||
} else if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
|
||||
using StoreT = AlignedVector<OutT, VecSize>;
|
||||
StoreT dest_vec;
|
||||
const float max_bound = 448.f;
|
||||
const float min_bound = -448.f;
|
||||
for (int j = 0; j < VecSize; j++) {
|
||||
float quant_value = max_bound * scale * static_cast<float>(src_vec[j]);
|
||||
quant_value = quant_value > max_bound ? max_bound : quant_value;
|
||||
quant_value = quant_value < min_bound ? min_bound : quant_value;
|
||||
dest_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(quant_value);
|
||||
}
|
||||
Store<phi::dtype::float8_e4m3fn, VecSize>(dest_vec, &dest_row_ptr[tid]);
|
||||
} else {
|
||||
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
|
||||
}
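For context on the 448.f bounds above: 448 is the largest finite magnitude representable in float8 E4M3, so values are clamped to [-448, 448] before the cast, the same bounds the w4afp8 path passes to permute_x_kernel earlier in this diff. Written out for one element (a sketch that mirrors the loop above):

constexpr float kFp8E4M3MaxFinite = 448.0f;
float q = max_bound * scale * static_cast<float>(src_vec[j]);
q = fminf(fmaxf(q, -kFp8E4M3MaxFinite), kFp8E4M3MaxFinite);
dest_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(q);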
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.